Implement sorting and pagination for octavia

Use glance sorting and pagination from inside the SQLAlchemy query
to handle the sorting and pagination for octavia.

Change-Id: I5489c5c89691b8871e32caf3f85ab1978bc3618c
Co-Authored-By: Adam Harwell <flux.adam@gmail.com>
Co-Authored-By: Lubosz "diltram" Kosnik <lubosz.kosnik@intel.com>
Closes-Bug: #1596628
Closes-Bug: #1596625
This commit is contained in:
Carlos D. Garza 2016-10-04 20:58:01 -05:00 committed by Adam Harwell
parent fb0da76c27
commit 9bfa58af9f
52 changed files with 1802 additions and 137 deletions

View File

@ -16,6 +16,12 @@
# The default value is the hostname of the host machine. # The default value is the hostname of the host machine.
# host = # host =
# Base URI for the API for use in pagination links.
# This will be autodetected from the request if not overridden here.
# Example:
# api_base_uri = http://localhost:9876
# api_base_uri =
# AMQP Transport URL # AMQP Transport URL
# For Single Host, specify one full transport URL: # For Single Host, specify one full transport URL:
# transport_url = rabbit://<user>:<pass>@127.0.0.1:5672/<vhost> # transport_url = rabbit://<user>:<pass>@127.0.0.1:5672/<vhost>

View File

@ -14,6 +14,8 @@
from pecan import hooks from pecan import hooks
from octavia.api.common import pagination
from octavia.common import constants
from octavia.common import context from octavia.common import context
@ -23,3 +25,14 @@ class ContextHook(hooks.PecanHook):
def on_route(self, state): def on_route(self, state):
context_obj = context.Context.from_environ(state.request.environ) context_obj = context.Context.from_environ(state.request.environ)
state.request.context['octavia_context'] = context_obj state.request.context['octavia_context'] = context_obj
class QueryParametersHook(hooks.PecanHook):
def before(self, state):
if state.request.method != 'GET':
return
state.request.context[
constants.PAGINATION_HELPER] = pagination.PaginationHelper(
state.request.params.mixed())

View File

@ -0,0 +1,269 @@
# Copyright 2016 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import copy
from oslo_log import log as logging
from pecan import request
import sqlalchemy
import sqlalchemy.sql as sa_sql
from octavia.api.common import types
from octavia.common.config import cfg
from octavia.common import constants
from octavia.common import exceptions
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
class PaginationHelper(object):
"""Class helping to interact with pagination functionality
Pass this class to `db.repositories` to apply it on query
"""
def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR):
"""Pagination Helper takes params and a default sort direction
:param params: Contains the following:
limit: maximum number of items to return
marker: the last item of the previous page; we return
the next results after this value.
sort: array of attr by which results should be sorted
:param sort_dir: default direction to sort (asc, desc)
"""
self.marker = params.get('marker')
self.sort_dir = self._validate_sort_dir(sort_dir)
self.limit = self._parse_limit(params)
self.sort_keys = self._parse_sort_keys(params)
self.params = params
@staticmethod
def _parse_limit(params):
if CONF.pagination_max_limit == 'infinite':
page_max_limit = None
else:
page_max_limit = int(CONF.pagination_max_limit)
limit = params.get('limit', page_max_limit)
try:
# Deal with limit being a string or int meaning 'Unlimited'
if limit == 'infinite' or int(limit) < 1:
limit = None
# If we don't have a max, just use whatever limit is specified
elif page_max_limit is None:
limit = int(limit)
# Otherwise, we need to compare against the max
else:
limit = min(int(limit), page_max_limit)
except ValueError:
raise exceptions.InvalidLimit(key=limit)
return limit
def _parse_sort_keys(self, params):
sort_keys_dirs = []
sort = params.get('sort')
sort_keys = params.get('sort_key')
if sort:
for sort_dir_key in sort.split(","):
comps = sort_dir_key.split(":")
if len(comps) == 1: # Use default sort order
sort_keys_dirs.append((comps[0], self.sort_dir))
elif len(comps) == 2:
sort_keys_dirs.append(
(comps[0], self._validate_sort_dir(comps[1])))
else:
raise exceptions.InvalidSortKey(key=comps)
elif sort_keys:
sort_keys = sort_keys.split(',')
sort_dirs = params.get('sort_dir')
if not sort_dirs:
sort_dirs = [self.sort_dir] * len(sort_keys)
else:
sort_dirs = sort_dirs.split(',')
if len(sort_dirs) < len(sort_keys):
sort_dirs += [self.sort_dir] * (len(sort_keys) -
len(sort_dirs))
for sk, sd in zip(sort_keys, sort_dirs):
sort_keys_dirs.append((sk, self._validate_sort_dir(sd)))
return sort_keys_dirs
def _parse_marker(self, session, model):
return session.query(model).filter_by(id=self.marker).one_or_none()
@staticmethod
def _get_default_column_value(column_type):
"""Return the default value of the columns from DB table
In postgreDB case, if no right default values are being set, an
psycopg2.DataError will be thrown.
"""
type_schema = {
'datetime': None,
'big_integer': 0,
'integer': 0,
'string': ''
}
if isinstance(column_type, sa_sql.type_api.Variant):
return PaginationHelper._get_default_column_value(column_type.impl)
return type_schema[column_type.__visit_name__]
@staticmethod
def _validate_sort_dir(sort_dir):
sort_dir = sort_dir.lower()
if sort_dir not in constants.ALLOWED_SORT_DIR:
raise exceptions.InvalidSortDirection(key=sort_dir)
return sort_dir
def _make_links(self, model_list):
if CONF.api_base_uri:
path_url = "{api_base_url}{path}".format(
api_base_url=CONF.api_base_uri.rstrip('/'), path=request.path)
else:
path_url = request.path_url
links = []
if model_list:
prev_attr = ["limit={}".format(self.limit)]
if self.params.get('sort'):
prev_attr.append("sort={}".format(self.params.get('sort')))
if self.params.get('sort_key'):
prev_attr.append("sort_key={}".format(
self.params.get('sort_key')))
next_attr = copy.copy(prev_attr)
if self.marker:
prev_attr.append("marker={}".format(model_list[0].get('id')))
prev_link = {
"rel": "previous",
"href": "{url}?{params}".format(
url=path_url,
params="&".join(prev_attr))
}
links.append(prev_link)
# TODO(rm_work) Do we need to know when there are more vs exact?
# We safely know if we have a full page, but it might include the
# last element or it might not, it is unclear
if len(model_list) >= self.limit:
next_attr.append("marker={}".format(model_list[-1].get('id')))
next_link = {
"rel": "next",
"href": "{url}?{params}".format(
url=path_url,
params="&".join(next_attr))
}
links.append(next_link)
links = [types.PageType(**link) for link in links]
return links
def apply(self, query, model):
"""Returns a query with sorting / pagination criteria added.
Pagination works by requiring a unique sort_key specified by sort_keys.
(If sort_keys is not unique, then we risk looping through values.)
We use the last row in the previous page as the pagination 'marker'.
So we must return values that follow the passed marker in the order.
With a single-valued sort_key, this would be easy: sort_key > X.
With a compound-values sort_key, (k1, k2, k3) we must do this to repeat
the lexicographical ordering:
(k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
We also have to cope with different sort_directions.
Typically, the id of the last row is used as the client-facing
pagination marker, then the actual marker object must be fetched from
the db and passed in to us as marker.
:param query: the query object to which we should add paging/sorting
:param model: the ORM model class
:rtype: sqlalchemy.orm.query.Query
:returns: The query with sorting/pagination added.
"""
# Add sorting
if CONF.allow_sorting:
# Add default sort keys (if they are OK for the model)
keys_only = [k[0] for k in self.sort_keys]
for key in constants.DEFAULT_SORT_KEYS:
if key not in keys_only and hasattr(model, key):
self.sort_keys.append((key, self.sort_dir))
for current_sort_key, current_sort_dir in self.sort_keys:
sort_dir_func = {
constants.ASC: sqlalchemy.asc,
constants.DESC: sqlalchemy.desc,
}[current_sort_dir]
try:
sort_key_attr = getattr(model, current_sort_key)
except AttributeError:
raise exceptions.InvalidSortKey(key=current_sort_key)
query = query.order_by(sort_dir_func(sort_key_attr))
# Add pagination
if CONF.allow_pagination:
default = '' # Default to an empty string if NULL
if self.marker is not None:
marker_object = self._parse_marker(query.session, model)
if not marker_object:
raise exceptions.InvalidMarker(key=self.marker)
marker_values = []
for sort_key, _ in self.sort_keys:
v = getattr(marker_object, sort_key)
if v is None:
v = default
marker_values.append(v)
# Build up an array of sort criteria as in the docstring
criteria_list = []
for i in range(len(self.sort_keys)):
crit_attrs = []
for j in range(i):
model_attr = getattr(model, self.sort_keys[j][0])
default = PaginationHelper._get_default_column_value(
model_attr.property.columns[0].type)
attr = sa_sql.expression.case(
[(model_attr != None, # noqa: E711
model_attr), ], else_=default)
crit_attrs.append((attr == marker_values[j]))
model_attr = getattr(model, self.sort_keys[i][0])
default = PaginationHelper._get_default_column_value(
model_attr.property.columns[0].type)
attr = sa_sql.expression.case(
[(model_attr != None, # noqa: E711
model_attr), ], else_=default)
this_sort_dir = self.sort_keys[i][1]
if this_sort_dir == constants.DESC:
crit_attrs.append((attr < marker_values[i]))
elif this_sort_dir == constants.ASC:
crit_attrs.append((attr > marker_values[i]))
else:
raise exceptions.InvalidSortDirection(
key=this_sort_dir)
criteria = sa_sql.and_(*crit_attrs)
criteria_list.append(criteria)
f = sa_sql.or_(*criteria_list)
query = query.filter(f)
if self.limit is not None:
query = query.limit(self.limit)
model_list = query.all()
links = None
if CONF.allow_pagination:
links = self._make_links(model_list)
return model_list, links

View File

@ -147,3 +147,8 @@ class IdOnlyType(BaseType):
class NameOnlyType(BaseType): class NameOnlyType(BaseType):
name = wtypes.wsattr(wtypes.StringType(max_length=255), mandatory=True) name = wtypes.wsattr(wtypes.StringType(max_length=255), mandatory=True)
class PageType(BaseType):
href = wtypes.StringType()
rel = wtypes.StringType()

View File

@ -19,7 +19,9 @@ from octavia.api.common import hooks
app = { app = {
'root': 'octavia.api.root_controller.RootController', 'root': 'octavia.api.root_controller.RootController',
'modules': ['octavia.api'], 'modules': ['octavia.api'],
'hooks': [hooks.ContextHook()], 'hooks': [
hooks.ContextHook(),
hooks.QueryParametersHook()],
'debug': False 'debug': False
} }

View File

@ -52,7 +52,7 @@ class L7PolicyController(base.BaseController):
def get_all(self): def get_all(self):
"""Lists all l7policies of a listener.""" """Lists all l7policies of a listener."""
context = pecan.request.context.get('octavia_context') context = pecan.request.context.get('octavia_context')
db_l7policies = self.repositories.l7policy.get_all( db_l7policies, _ = self.repositories.l7policy.get_all(
context.session, listener_id=self.listener_id) context.session, listener_id=self.listener_id)
return self._convert_db_to_type(db_l7policies, return self._convert_db_to_type(db_l7policies,
[l7policy_types.L7PolicyResponse]) [l7policy_types.L7PolicyResponse])

View File

@ -52,7 +52,7 @@ class L7RuleController(base.BaseController):
def get_all(self): def get_all(self):
"""Lists all l7rules of a l7policy.""" """Lists all l7rules of a l7policy."""
context = pecan.request.context.get('octavia_context') context = pecan.request.context.get('octavia_context')
db_l7rules = self.repositories.l7rule.get_all( db_l7rules, _ = self.repositories.l7rule.get_all(
context.session, l7policy_id=self.l7policy_id) context.session, l7policy_id=self.l7policy_id)
return self._convert_db_to_type(db_l7rules, return self._convert_db_to_type(db_l7rules,
[l7rule_types.L7RuleResponse]) [l7rule_types.L7RuleResponse])

View File

@ -66,12 +66,16 @@ class ListenersController(base.BaseController):
return self._convert_db_to_type(db_listener, return self._convert_db_to_type(db_listener,
listener_types.ListenerResponse) listener_types.ListenerResponse)
@wsme_pecan.wsexpose([listener_types.ListenerResponse]) @wsme_pecan.wsexpose([listener_types.ListenerResponse],
ignore_extra_args=True)
def get_all(self): def get_all(self):
"""Lists all listeners on a load balancer.""" """Lists all listeners on a load balancer."""
context = pecan.request.context.get('octavia_context') context = pecan.request.context.get('octavia_context')
db_listeners = self.repositories.listener.get_all( pcontext = pecan.request.context
context.session, load_balancer_id=self.load_balancer_id) db_listeners, _ = self.repositories.listener.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
load_balancer_id=self.load_balancer_id)
return self._convert_db_to_type(db_listeners, return self._convert_db_to_type(db_listeners,
[listener_types.ListenerResponse]) [listener_types.ListenerResponse])

View File

@ -54,16 +54,22 @@ class LoadBalancersController(base.BaseController):
lb_types.LoadBalancerResponse) lb_types.LoadBalancerResponse)
@wsme_pecan.wsexpose([lb_types.LoadBalancerResponse], wtypes.text, @wsme_pecan.wsexpose([lb_types.LoadBalancerResponse], wtypes.text,
wtypes.text) wtypes.text, ignore_extra_args=True)
def get_all(self, tenant_id=None, project_id=None): def get_all(self, tenant_id=None, project_id=None):
"""Lists all load balancers.""" """Lists all load balancers."""
# NOTE(blogan): tenant_id and project_id are optional query parameters # NOTE(blogan): tenant_id and project_id are optional query parameters
# tenant_id and project_id are the same thing. tenant_id will be kept # tenant_id and project_id are the same thing. tenant_id will be kept
# around for a short amount of time. # around for a short amount of time.
context = pecan.request.context.get('octavia_context')
pcontext = pecan.request.context
context = pcontext.get('octavia_context')
project_id = context.project_id or project_id or tenant_id project_id = context.project_id or project_id or tenant_id
load_balancers = self.repositories.load_balancer.get_all(
context.session, project_id=project_id) load_balancers, _ = self.repositories.load_balancer.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
project_id=project_id)
return self._convert_db_to_type(load_balancers, return self._convert_db_to_type(load_balancers,
[lb_types.LoadBalancerResponse]) [lb_types.LoadBalancerResponse])

View File

@ -49,12 +49,16 @@ class MembersController(base.BaseController):
db_member = self._get_db_member(context.session, id) db_member = self._get_db_member(context.session, id)
return self._convert_db_to_type(db_member, member_types.MemberResponse) return self._convert_db_to_type(db_member, member_types.MemberResponse)
@wsme_pecan.wsexpose([member_types.MemberResponse]) @wsme_pecan.wsexpose([member_types.MemberResponse], ignore_extra_args=True)
def get_all(self): def get_all(self):
"""Lists all pool members of a pool.""" """Lists all pool members of a pool."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
db_members = self.repositories.member.get_all( context = pcontext.get('octavia_context')
context.session, pool_id=self.pool_id)
db_members, _ = self.repositories.member.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
pool_id=self.pool_id)
return self._convert_db_to_type(db_members, return self._convert_db_to_type(db_members,
[member_types.MemberResponse]) [member_types.MemberResponse])

View File

@ -49,18 +49,23 @@ class PoolsController(base.BaseController):
db_pool = self._get_db_pool(context.session, id) db_pool = self._get_db_pool(context.session, id)
return self._convert_db_to_type(db_pool, pool_types.PoolResponse) return self._convert_db_to_type(db_pool, pool_types.PoolResponse)
@wsme_pecan.wsexpose([pool_types.PoolResponse], wtypes.text) @wsme_pecan.wsexpose([pool_types.PoolResponse], wtypes.text,
ignore_extra_args=True)
def get_all(self, listener_id=None): def get_all(self, listener_id=None):
"""Lists all pools on a listener or loadbalancer.""" """Lists all pools on a listener or loadbalancer."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
context = pcontext.get('octavia_context')
if listener_id is not None: if listener_id is not None:
self.listener_id = listener_id self.listener_id = listener_id
if self.listener_id: if self.listener_id:
pools = self._get_db_listener(context.session, pools = self._get_db_listener(context.session,
self.listener_id).pools self.listener_id).pools
else: else:
pools = self.repositories.pool.get_all( pools, _ = self.repositories.pool.get_all(
context.session, load_balancer_id=self.load_balancer_id) context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
load_balancer_id=self.load_balancer_id)
return self._convert_db_to_type(pools, [pool_types.PoolResponse]) return self._convert_db_to_type(pools, [pool_types.PoolResponse])
def _get_affected_listener_ids(self, session, pool=None): def _get_affected_listener_ids(self, session, pool=None):

View File

@ -42,7 +42,7 @@ class QuotasController(base.BaseController):
def get_all(self): def get_all(self):
"""List all non-default quotas.""" """List all non-default quotas."""
context = pecan.request.context.get('octavia_context') context = pecan.request.context.get('octavia_context')
db_quotas = self.repositories.quotas.get_all(context.session) db_quotas, _ = self.repositories.quotas.get_all(context.session)
quotas = quota_types.QuotaAllResponse.from_data_model(db_quotas) quotas = quota_types.QuotaAllResponse.from_data_model(db_quotas)
return quotas return quotas

View File

@ -62,10 +62,12 @@ class HealthMonitorController(base.BaseController):
db_hm, hm_types.HealthMonitorResponse) db_hm, hm_types.HealthMonitorResponse)
return hm_types.HealthMonitorRootResponse(healthmonitor=result) return hm_types.HealthMonitorRootResponse(healthmonitor=result)
@wsme_pecan.wsexpose(hm_types.HealthMonitorsRootResponse, wtypes.text) @wsme_pecan.wsexpose(hm_types.HealthMonitorsRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None): def get_all(self, project_id=None):
"""Gets all health monitors.""" """Gets all health monitors."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
context = pcontext.get('octavia_context')
if context.is_admin or CONF.auth_strategy == constants.NOAUTH: if context.is_admin or CONF.auth_strategy == constants.NOAUTH:
if project_id: if project_id:
project_id = {'project_id': project_id} project_id = {'project_id': project_id}
@ -73,11 +75,14 @@ class HealthMonitorController(base.BaseController):
project_id = {} project_id = {}
else: else:
project_id = {'project_id': context.project_id} project_id = {'project_id': context.project_id}
db_hm = self.repositories.health_monitor.get_all( db_hm, links = self.repositories.health_monitor.get_all(
context.session, show_deleted=False, **project_id) context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**project_id)
result = self._convert_db_to_type( result = self._convert_db_to_type(
db_hm, [hm_types.HealthMonitorResponse]) db_hm, [hm_types.HealthMonitorResponse])
return hm_types.HealthMonitorsRootResponse(healthmonitors=result) return hm_types.HealthMonitorsRootResponse(
healthmonitors=result, healthmonitors_links=links)
def _get_affected_listener_ids(self, session, hm): def _get_affected_listener_ids(self, session, hm):
"""Gets a list of all listeners this request potentially affects.""" """Gets a list of all listeners this request potentially affects."""

View File

@ -51,10 +51,12 @@ class L7PolicyController(base.BaseController):
l7policy_types.L7PolicyResponse) l7policy_types.L7PolicyResponse)
return l7policy_types.L7PolicyRootResponse(l7policy=result) return l7policy_types.L7PolicyRootResponse(l7policy=result)
@wsme_pecan.wsexpose(l7policy_types.L7PoliciesRootResponse, wtypes.text) @wsme_pecan.wsexpose(l7policy_types.L7PoliciesRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None): def get_all(self, project_id=None):
"""Lists all l7policies of a listener.""" """Lists all l7policies of a listener."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
context = pcontext.get('octavia_context')
if context.is_admin or CONF.auth_strategy == constants.NOAUTH: if context.is_admin or CONF.auth_strategy == constants.NOAUTH:
if project_id: if project_id:
project_id = {'project_id': project_id} project_id = {'project_id': project_id}
@ -62,11 +64,14 @@ class L7PolicyController(base.BaseController):
project_id = {} project_id = {}
else: else:
project_id = {'project_id': context.project_id} project_id = {'project_id': context.project_id}
db_l7policies = self.repositories.l7policy.get_all( db_l7policies, links = self.repositories.l7policy.get_all(
context.session, show_deleted=False, **project_id) context.session, show_deleted=False,
result = self._convert_db_to_type(db_l7policies, pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
[l7policy_types.L7PolicyResponse]) **project_id)
return l7policy_types.L7PoliciesRootResponse(l7policies=result) result = self._convert_db_to_type(
db_l7policies, [l7policy_types.L7PolicyResponse])
return l7policy_types.L7PoliciesRootResponse(
l7policies=result, l7policies_links=links)
def _test_lb_and_listener_statuses(self, session, lb_id, listener_ids): def _test_lb_and_listener_statuses(self, session, lb_id, listener_ids):
"""Verify load balancer is in a mutable state.""" """Verify load balancer is in a mutable state."""

View File

@ -49,15 +49,19 @@ class L7RuleController(base.BaseController):
l7rule_types.L7RuleResponse) l7rule_types.L7RuleResponse)
return l7rule_types.L7RuleRootResponse(rule=result) return l7rule_types.L7RuleRootResponse(rule=result)
@wsme_pecan.wsexpose(l7rule_types.L7RulesRootResponse, wtypes.text) @wsme_pecan.wsexpose(l7rule_types.L7RulesRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self): def get_all(self):
"""Lists all l7rules of a l7policy.""" """Lists all l7rules of a l7policy."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
db_l7rules = self.repositories.l7rule.get_all( context = pcontext.get('octavia_context')
context.session, show_deleted=False, l7policy_id=self.l7policy_id) db_l7rules, links = self.repositories.l7rule.get_all(
result = self._convert_db_to_type(db_l7rules, context.session, show_deleted=False, l7policy_id=self.l7policy_id,
[l7rule_types.L7RuleResponse]) pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
return l7rule_types.L7RulesRootResponse(rules=result) result = self._convert_db_to_type(
db_l7rules, [l7rule_types.L7RuleResponse])
return l7rule_types.L7RulesRootResponse(
rules=result, rules_links=links)
def _test_lb_listener_policy_statuses(self, session): def _test_lb_listener_policy_statuses(self, session):
"""Verify load balancer is in a mutable state.""" """Verify load balancer is in a mutable state."""

View File

@ -64,10 +64,12 @@ class ListenersController(base.BaseController):
listener_types.ListenerResponse) listener_types.ListenerResponse)
return listener_types.ListenerRootResponse(listener=result) return listener_types.ListenerRootResponse(listener=result)
@wsme_pecan.wsexpose(listener_types.ListenersRootResponse, wtypes.text) @wsme_pecan.wsexpose(listener_types.ListenersRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None): def get_all(self, project_id=None):
"""Lists all listeners.""" """Lists all listeners."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
context = pcontext.get('octavia_context')
if context.is_admin or CONF.auth_strategy == constants.NOAUTH: if context.is_admin or CONF.auth_strategy == constants.NOAUTH:
if project_id: if project_id:
project_id = {'project_id': project_id} project_id = {'project_id': project_id}
@ -75,11 +77,14 @@ class ListenersController(base.BaseController):
project_id = {} project_id = {}
else: else:
project_id = {'project_id': context.project_id} project_id = {'project_id': context.project_id}
db_listeners = self.repositories.listener.get_all( db_listeners, links = self.repositories.listener.get_all(
context.session, show_deleted=False, **project_id) context.session, show_deleted=False,
result = self._convert_db_to_type(db_listeners, pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
[listener_types.ListenerResponse]) **project_id)
return listener_types.ListenersRootResponse(listeners=result) result = self._convert_db_to_type(
db_listeners, [listener_types.ListenerResponse])
return listener_types.ListenersRootResponse(
listeners=result, listeners_links=links)
def _test_lb_and_listener_statuses( def _test_lb_and_listener_statuses(
self, session, lb_id, id=None, self, session, lb_id, id=None,

View File

@ -50,14 +50,16 @@ class LoadBalancersController(base.BaseController):
"""Gets a single load balancer's details.""" """Gets a single load balancer's details."""
context = pecan.request.context.get('octavia_context') context = pecan.request.context.get('octavia_context')
load_balancer = self._get_db_lb(context.session, id) load_balancer = self._get_db_lb(context.session, id)
result = self._convert_db_to_type(load_balancer, result = self._convert_db_to_type(
lb_types.LoadBalancerResponse) load_balancer, lb_types.LoadBalancerResponse)
return lb_types.LoadBalancerRootResponse(loadbalancer=result) return lb_types.LoadBalancerRootResponse(loadbalancer=result)
@wsme_pecan.wsexpose(lb_types.LoadBalancersRootResponse, wtypes.text) @wsme_pecan.wsexpose(lb_types.LoadBalancersRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None): def get_all(self, project_id=None):
"""Lists all load balancers.""" """Lists all load balancers."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
context = pcontext.get('octavia_context')
if context.is_admin or CONF.auth_strategy == constants.NOAUTH: if context.is_admin or CONF.auth_strategy == constants.NOAUTH:
if project_id: if project_id:
project_id = {'project_id': project_id} project_id = {'project_id': project_id}
@ -65,11 +67,14 @@ class LoadBalancersController(base.BaseController):
project_id = {} project_id = {}
else: else:
project_id = {'project_id': context.project_id} project_id = {'project_id': context.project_id}
load_balancers = self.repositories.load_balancer.get_all( load_balancers, links = self.repositories.load_balancer.get_all(
context.session, show_deleted=False, **project_id) context.session, show_deleted=False,
result = self._convert_db_to_type(load_balancers, pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
[lb_types.LoadBalancerResponse]) **project_id)
return lb_types.LoadBalancersRootResponse(loadbalancers=result) result = self._convert_db_to_type(
load_balancers, [lb_types.LoadBalancerResponse])
return lb_types.LoadBalancersRootResponse(
loadbalancers=result, loadbalancers_links=links)
def _test_lb_status(self, session, id, lb_status=constants.PENDING_UPDATE): def _test_lb_status(self, session, id, lb_status=constants.PENDING_UPDATE):
"""Verify load balancer is in a mutable state.""" """Verify load balancer is in a mutable state."""

View File

@ -50,15 +50,20 @@ class MembersController(base.BaseController):
member_types.MemberResponse) member_types.MemberResponse)
return member_types.MemberRootResponse(member=result) return member_types.MemberRootResponse(member=result)
@wsme_pecan.wsexpose(member_types.MembersRootResponse, wtypes.text) @wsme_pecan.wsexpose(member_types.MembersRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self): def get_all(self):
"""Lists all pool members of a pool.""" """Lists all pool members of a pool."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
db_members = self.repositories.member.get_all( context = pcontext.get('octavia_context')
context.session, show_deleted=False, pool_id=self.pool_id) db_members, links = self.repositories.member.get_all(
context.session, show_deleted=False,
pool_id=self.pool_id,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
result = self._convert_db_to_type( result = self._convert_db_to_type(
db_members, [member_types.MemberResponse]) db_members, [member_types.MemberResponse])
return member_types.MembersRootResponse(members=result) return member_types.MembersRootResponse(
members=result, members_links=links)
def _get_affected_listener_ids(self, session, member=None): def _get_affected_listener_ids(self, session, member=None):
"""Gets a list of all listeners this request potentially affects.""" """Gets a list of all listeners this request potentially affects."""

View File

@ -31,6 +31,7 @@ from octavia.common import data_models
from octavia.common import exceptions from octavia.common import exceptions
from octavia.db import api as db_api from octavia.db import api as db_api
from octavia.db import prepare as db_prepare from octavia.db import prepare as db_prepare
from octavia.i18n import _
CONF = cfg.CONF CONF = cfg.CONF
@ -51,10 +52,12 @@ class PoolsController(base.BaseController):
result = self._convert_db_to_type(db_pool, pool_types.PoolResponse) result = self._convert_db_to_type(db_pool, pool_types.PoolResponse)
return pool_types.PoolRootResponse(pool=result) return pool_types.PoolRootResponse(pool=result)
@wsme_pecan.wsexpose(pool_types.PoolsRootResponse, wtypes.text) @wsme_pecan.wsexpose(pool_types.PoolsRootResponse, wtypes.text,
ignore_extra_args=True)
def get_all(self, project_id=None): def get_all(self, project_id=None):
"""Lists all pools.""" """Lists all pools."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
context = pcontext.get('octavia_context')
if context.is_admin or CONF.auth_strategy == constants.NOAUTH: if context.is_admin or CONF.auth_strategy == constants.NOAUTH:
if project_id: if project_id:
project_id = {'project_id': project_id} project_id = {'project_id': project_id}
@ -62,10 +65,12 @@ class PoolsController(base.BaseController):
project_id = {} project_id = {}
else: else:
project_id = {'project_id': context.project_id} project_id = {'project_id': context.project_id}
db_pools = self.repositories.pool.get_all( db_pools, links = self.repositories.pool.get_all(
context.session, show_deleted=False, **project_id) context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**project_id)
result = self._convert_db_to_type(db_pools, [pool_types.PoolResponse]) result = self._convert_db_to_type(db_pools, [pool_types.PoolResponse])
return pool_types.PoolsRootResponse(pools=result) return pool_types.PoolsRootResponse(pools=result, pools_links=links)
def _get_affected_listener_ids(self, pool): def _get_affected_listener_ids(self, pool):
"""Gets a list of all listeners this request potentially affects.""" """Gets a list of all listeners this request potentially affects."""

View File

@ -38,12 +38,26 @@ class QuotasController(base.BaseController):
db_quotas = self._get_db_quotas(context.session, project_id) db_quotas = self._get_db_quotas(context.session, project_id)
return self._convert_db_to_type(db_quotas, quota_types.QuotaResponse) return self._convert_db_to_type(db_quotas, quota_types.QuotaResponse)
@wsme_pecan.wsexpose(quota_types.QuotaAllResponse) @wsme_pecan.wsexpose(quota_types.QuotaAllResponse,
def get_all(self): ignore_extra_args=True)
def get_all(self, tenant_id=None, project_id=None):
"""List all non-default quotas.""" """List all non-default quotas."""
context = pecan.request.context.get('octavia_context') pcontext = pecan.request.context
db_quotas = self.repositories.quotas.get_all(context.session) context = pcontext.get('octavia_context')
if context.is_admin or CONF.auth_strategy == constants.NOAUTH:
if project_id or tenant_id:
project_id = {'project_id': project_id or tenant_id}
else:
project_id = {}
else:
project_id = {'project_id': context.project_id}
db_quotas, links = self.repositories.quotas.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**project_id)
quotas = quota_types.QuotaAllResponse.from_data_model(db_quotas) quotas = quota_types.QuotaAllResponse.from_data_model(db_quotas)
quotas.quotas_links = links
return quotas return quotas
@wsme_pecan.wsexpose(quota_types.QuotaResponse, wtypes.text, @wsme_pecan.wsexpose(quota_types.QuotaResponse, wtypes.text,

View File

@ -69,6 +69,7 @@ class HealthMonitorRootResponse(types.BaseType):
class HealthMonitorsRootResponse(types.BaseType): class HealthMonitorsRootResponse(types.BaseType):
healthmonitors = wtypes.wsattr([HealthMonitorResponse]) healthmonitors = wtypes.wsattr([HealthMonitorResponse])
healthmonitors_links = wtypes.wsattr([types.PageType])
class HealthMonitorPOST(BaseHealthMonitorType): class HealthMonitorPOST(BaseHealthMonitorType):

View File

@ -74,6 +74,7 @@ class L7PolicyRootResponse(types.BaseType):
class L7PoliciesRootResponse(types.BaseType): class L7PoliciesRootResponse(types.BaseType):
l7policies = wtypes.wsattr([L7PolicyResponse]) l7policies = wtypes.wsattr([L7PolicyResponse])
l7policies_links = wtypes.wsattr([types.PageType])
class L7PolicyPOST(BaseL7PolicyType): class L7PolicyPOST(BaseL7PolicyType):

View File

@ -56,6 +56,7 @@ class L7RuleRootResponse(types.BaseType):
class L7RulesRootResponse(types.BaseType): class L7RulesRootResponse(types.BaseType):
rules = wtypes.wsattr([L7RuleResponse]) rules = wtypes.wsattr([L7RuleResponse])
rules_links = wtypes.wsattr([types.PageType])
class L7RulePOST(BaseL7Type): class L7RulePOST(BaseL7Type):

View File

@ -87,6 +87,7 @@ class ListenerRootResponse(types.BaseType):
class ListenersRootResponse(types.BaseType): class ListenersRootResponse(types.BaseType):
listeners = wtypes.wsattr([ListenerResponse]) listeners = wtypes.wsattr([ListenerResponse])
listeners_links = wtypes.wsattr([types.PageType])
class ListenerPOST(BaseListenerType): class ListenerPOST(BaseListenerType):

View File

@ -99,6 +99,7 @@ class LoadBalancerFullRootResponse(LoadBalancerRootResponse):
class LoadBalancersRootResponse(types.BaseType): class LoadBalancersRootResponse(types.BaseType):
loadbalancers = wtypes.wsattr([LoadBalancerResponse]) loadbalancers = wtypes.wsattr([LoadBalancerResponse])
loadbalancers_links = wtypes.wsattr([types.PageType])
class LoadBalancerPOST(BaseLoadBalancerType): class LoadBalancerPOST(BaseLoadBalancerType):

View File

@ -60,6 +60,7 @@ class MemberRootResponse(types.BaseType):
class MembersRootResponse(types.BaseType): class MembersRootResponse(types.BaseType):
members = wtypes.wsattr([MemberResponse]) members = wtypes.wsattr([MemberResponse])
members_links = wtypes.wsattr([types.PageType])
class MemberPOST(BaseMemberType): class MemberPOST(BaseMemberType):

View File

@ -109,6 +109,7 @@ class PoolRootResponse(types.BaseType):
class PoolsRootResponse(types.BaseType): class PoolsRootResponse(types.BaseType):
pools = wtypes.wsattr([PoolResponse]) pools = wtypes.wsattr([PoolResponse])
pools_links = wtypes.wsattr([types.PageType])
class PoolPOST(BasePoolType): class PoolPOST(BasePoolType):

View File

@ -62,6 +62,7 @@ class QuotaAllBase(base.BaseType):
class QuotaAllResponse(base.BaseType): class QuotaAllResponse(base.BaseType):
quotas = wtypes.wsattr([QuotaAllBase]) quotas = wtypes.wsattr([QuotaAllBase])
quotas_links = wtypes.wsattr([base.PageType])
@classmethod @classmethod
def from_data_model(cls, data_model, children=False): def from_data_model(cls, data_model, children=False):

View File

@ -41,22 +41,21 @@ core_opts = [
help=_("The auth strategy for API requests.")), help=_("The auth strategy for API requests.")),
cfg.StrOpt('api_handler', default='queue_producer', cfg.StrOpt('api_handler', default='queue_producer',
help=_("The handler that the API communicates with")), help=_("The handler that the API communicates with")),
cfg.StrOpt('api_paste_config', default="api-paste.ini", cfg.BoolOpt('allow_pagination', default=True,
help=_("The API paste config file to use")),
cfg.StrOpt('api_extensions_path', default="",
help=_("The path for API extensions")),
cfg.BoolOpt('allow_bulk', default=True,
help=_("Allow the usage of the bulk API")),
cfg.BoolOpt('allow_pagination', default=False,
help=_("Allow the usage of the pagination")), help=_("Allow the usage of the pagination")),
cfg.BoolOpt('allow_sorting', default=False, cfg.BoolOpt('allow_sorting', default=True,
help=_("Allow the usage of the sorting")), help=_("Allow the usage of the sorting")),
cfg.StrOpt('pagination_max_limit', default="-1", cfg.StrOpt('pagination_max_limit',
default=str(constants.DEFAULT_PAGE_SIZE),
help=_("The maximum number of items returned in a single " help=_("The maximum number of items returned in a single "
"response. The string 'infinite' or a negative " "response. The string 'infinite' or a negative "
"integer value means 'no limit'")), "integer value means 'no limit'")),
cfg.HostnameOpt('host', default=utils.get_hostname(), cfg.HostnameOpt('host', default=utils.get_hostname(),
help=_("The hostname Octavia is running on")), help=_("The hostname Octavia is running on")),
cfg.StrOpt('api_base_uri',
help=_("Base URI for the API for use in pagination links. "
"This will be autodetected from the request if not "
"overridden here.")),
cfg.StrOpt('octavia_plugins', cfg.StrOpt('octavia_plugins',
default='hot_plug_plugin', default='hot_plug_plugin',
help=_('Name of the controller plugin to use')) help=_('Name of the controller plugin to use'))

View File

@ -406,7 +406,18 @@ KEYSTONE = 'keystone'
NOAUTH = 'noauth' NOAUTH = 'noauth'
TESTING = 'testing' TESTING = 'testing'
# Amphora distro-specific data
UBUNTU_AMP_NET_DIR_TEMPLATE = '/etc/netns/{netns}/network/interfaces.d/' UBUNTU_AMP_NET_DIR_TEMPLATE = '/etc/netns/{netns}/network/interfaces.d/'
RH_AMP_NET_DIR_TEMPLATE = '/etc/netns/{netns}/sysconfig/network-scripts/' RH_AMP_NET_DIR_TEMPLATE = '/etc/netns/{netns}/sysconfig/network-scripts/'
UBUNTU = 'ubuntu' UBUNTU = 'ubuntu'
CENTOS = 'centos' CENTOS = 'centos'
# Pagination, sorting, filtering values
APPLICATION_JSON = 'application/json'
PAGINATION_HELPER = 'pagination_helper'
ASC = 'asc'
DESC = 'desc'
ALLOWED_SORT_DIR = (ASC, DESC)
DEFAULT_SORT_DIR = ASC
DEFAULT_SORT_KEYS = ['created_at', 'id']
DEFAULT_PAGE_SIZE = 1000

View File

@ -271,6 +271,11 @@ class MissingProjectID(OctaviaException):
message = _('Missing project ID in request where one is required.') message = _('Missing project ID in request where one is required.')
class MissingAPIProjectID(APIException):
message = _('Missing project ID in request where one is required.')
code = 400
class InvalidSubresource(APIException): class InvalidSubresource(APIException):
msg = _('%(resource)s %(id)s not found.') msg = _('%(resource)s %(id)s not found.')
code = 400 code = 400
@ -279,3 +284,23 @@ class InvalidSubresource(APIException):
class ValidationException(APIException): class ValidationException(APIException):
msg = _('Validation failure: %(detail)s') msg = _('Validation failure: %(detail)s')
code = 400 code = 400
class InvalidSortKey(APIException):
msg = _("Supplied sort key '%(key)s' is not valid.")
code = 400
class InvalidSortDirection(APIException):
msg = _("Supplied sort direction '%(key)s' is not valid.")
code = 400
class InvalidMarker(APIException):
msg = _("Supplied pagination marker '%(key)s' is not valid.")
code = 400
class InvalidLimit(APIException):
msg = _("Supplied pagination limit '%(key)s' is not valid.")
code = 400

View File

@ -31,7 +31,7 @@ class StatsMixin(object):
def get_listener_stats(self, session, listener_id): def get_listener_stats(self, session, listener_id):
"""Gets the listener statistics data_models object.""" """Gets the listener statistics data_models object."""
db_ls = self.listener_stats_repo.get_all( db_ls, _ = self.listener_stats_repo.get_all(
session, listener_id=listener_id) session, listener_id=listener_id)
if not db_ls: if not db_ls:
LOG.warning("Listener Statistics for Listener %s was not found", LOG.warning("Listener Statistics for Listener %s was not found",

View File

@ -69,7 +69,7 @@ class DatabaseCleanup(object):
seconds=CONF.house_keeping.amphora_expiry_age) seconds=CONF.house_keeping.amphora_expiry_age)
session = db_api.get_session() session = db_api.get_session()
amphora = self.amp_repo.get_all(session, status=constants.DELETED) amphora, _ = self.amp_repo.get_all(session, status=constants.DELETED)
for amp in amphora: for amp in amphora:
if self.amp_health_repo.check_amphora_expired(session, amp.id, if self.amp_health_repo.check_amphora_expired(session, amp.id,
@ -84,7 +84,7 @@ class DatabaseCleanup(object):
seconds=CONF.house_keeping.load_balancer_expiry_age) seconds=CONF.house_keeping.load_balancer_expiry_age)
session = db_api.get_session() session = db_api.get_session()
load_balancers = self.lb_repo.get_all( load_balancers, _ = self.lb_repo.get_all(
session, provisioning_status=constants.DELETED) session, provisioning_status=constants.DELETED)
for lb in load_balancers: for lb in load_balancers:

View File

@ -316,7 +316,7 @@ class ControllerWorker(base_taskflow.BaseTaskFlowEngine):
""" """
lb = self._lb_repo.get(db_apis.get_session(), lb = self._lb_repo.get(db_apis.get_session(),
id=load_balancer_id) id=load_balancer_id)
listeners = self._listener_repo.get_all( listeners, _ = self._listener_repo.get_all(
db_apis.get_session(), db_apis.get_session(),
load_balancer_id=load_balancer_id) load_balancer_id=load_balancer_id)

View File

@ -102,21 +102,31 @@ class BaseRepository(object):
return return
return model.to_data_model() return model.to_data_model()
def get_all(self, session, **filters): def get_all(self, session, pagination_helper=None, **filters):
"""Retrieves a list of entities from the database. """Retrieves a list of entities from the database.
:param session: A Sql Alchemy database session. :param session: A Sql Alchemy database session.
:param pagination_helper: Helper to apply pagination and sorting.
:param filters: Filters to decide which entities should be retrieved. :param filters: Filters to decide which entities should be retrieved.
:returns: [octavia.common.data_model] :returns: [octavia.common.data_model]
""" """
deleted = filters.pop('show_deleted', True) deleted = filters.pop('show_deleted', True)
model_list = session.query(self.model_class).filter_by(**filters) query = session.query(self.model_class).filter_by(**filters)
if not deleted: if not deleted:
model_list = model_list.filter( query = query.filter(
self.model_class.provisioning_status != consts.DELETED) self.model_class.provisioning_status != consts.DELETED)
model_list = model_list.all()
if pagination_helper:
model_list, links = pagination_helper.apply(
query, self.model_class)
else:
links = None
model_list = query.all()
data_model_list = [model.to_data_model() for model in model_list] data_model_list = [model.to_data_model() for model in model_list]
return data_model_list return data_model_list, links
def exists(self, session, id): def exists(self, session, id):
"""Determines whether an entity exists in the database by its id. """Determines whether an entity exists in the database by its id.
@ -1149,16 +1159,24 @@ class L7PolicyRepository(BaseRepository):
self._pool_check(session, l7policy.redirect_pool_id, self._pool_check(session, l7policy.redirect_pool_id,
listener.load_balancer_id, listener.project_id) listener.load_balancer_id, listener.project_id)
def get_all(self, session, **filters): def get_all(self, session, pagination_helper=None, **filters):
deleted = filters.pop('show_deleted', True) deleted = filters.pop('show_deleted', True)
l7policy_list = session.query(self.model_class).filter_by(**filters) query = session.query(self.model_class).filter_by(
**filters)
if not deleted: if not deleted:
l7policy_list = l7policy_list.filter( query = query.filter(
self.model_class.provisioning_status != consts.DELETED) self.model_class.provisioning_status != consts.DELETED)
l7policy_list = l7policy_list.order_by(self.model_class.position).all() if pagination_helper:
data_model_list = [p.to_data_model() for p in l7policy_list] model_list, links = pagination_helper.apply(
return data_model_list query, self.model_class)
else:
links = None
model_list = query.order_by(self.model_class.position).all()
data_model_list = [model.to_data_model() for model in model_list]
return data_model_list, links
def update(self, session, id, **model_kwargs): def update(self, session, id, **model_kwargs):
with session.begin(subtransactions=True): with session.begin(subtransactions=True):

View File

@ -20,6 +20,8 @@ import pecan
import pecan.testing import pecan.testing
from octavia.api import config as pconfig from octavia.api import config as pconfig
# needed for tests to function when run independently:
from octavia.common import config # noqa: F401
from octavia.common import constants from octavia.common import constants
from octavia.db import api as db_api from octavia.db import api as db_api
from octavia.db import repositories from octavia.db import repositories
@ -255,8 +257,8 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
self.lb_repo.update(db_api.get_session(), lb_id, self.lb_repo.update(db_api.get_session(), lb_id,
provisioning_status=prov_status, provisioning_status=prov_status,
operating_status=op_status) operating_status=op_status)
lb_listeners = self.listener_repo.get_all(db_api.get_session(), lb_listeners, _ = self.listener_repo.get_all(
load_balancer_id=lb_id) db_api.get_session(), load_balancer_id=lb_id)
for listener in lb_listeners: for listener in lb_listeners:
for pool in listener.pools: for pool in listener.pools:
self.pool_repo.update(db_api.get_session(), pool.id, self.pool_repo.update(db_api.get_session(), pool.id,

View File

@ -0,0 +1,225 @@
# Copyright 2016 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import operator
from oslo_serialization import jsonutils as json
from oslo_utils import uuidutils
from octavia.common import constants
from octavia.tests.functional.api.v1 import base
class TestApiSort(base.BaseAPITest):
def setUp(self):
super(TestApiSort, self).setUp()
self.random_name_desc = [
('b', 'g'), ('h', 'g'), ('b', 'a'), ('c', 'g'),
('g', 'c'), ('h', 'h'), ('a', 'e'), ('g', 'h'),
('g', 'd'), ('e', 'h'), ('h', 'e'), ('b', 'f'),
('b', 'h'), ('a', 'h'), ('g', 'g'), ('h', 'f'),
('c', 'h'), ('g', 'f'), ('f', 'f'), ('d', 'd'),
('g', 'b'), ('a', 'c'), ('h', 'a'), ('h', 'c'),
('e', 'd'), ('d', 'g'), ('c', 'b'), ('f', 'b'),
('c', 'c'), ('d', 'c'), ('f', 'a'), ('h', 'd'),
('f', 'c'), ('d', 'a'), ('d', 'e'), ('d', 'f'),
('g', 'e'), ('a', 'a'), ('e', 'c'), ('e', 'b'),
('f', 'g'), ('d', 'b'), ('e', 'a'), ('b', 'e'),
('f', 'h'), ('a', 'g'), ('c', 'd'), ('b', 'd'),
('b', 'b'), ('a', 'b'), ('f', 'd'), ('f', 'e'),
('c', 'a'), ('b', 'c'), ('e', 'f'), ('a', 'f'),
('e', 'e'), ('h', 'b'), ('d', 'h'), ('e', 'g'),
('c', 'e'), ('g', 'a'), ('a', 'd'), ('c', 'f')]
self.headers = {'accept': constants.APPLICATION_JSON,
'content-type': constants.APPLICATION_JSON}
self.lbs = []
self.lb_names = ['lb_c', 'lb_a', 'lb_b', 'lb_e', 'lb_d']
def _create_loadbalancers(self):
for name in self.lb_names:
lb = self.create_load_balancer(
{'subnet_id': uuidutils.generate_uuid()}, name=name)
self.lbs.append(lb)
def test_lb_keysort(self):
self._create_loadbalancers()
params = {'sort': 'name:desc',
'project_id': self.project_id}
resp = self.get(self.LBS_PATH, params=params,
headers=self.headers)
lbs = json.loads(resp.body)
act_names = [l['name'] for l in lbs]
ref_names = sorted(self.lb_names[:], reverse=True)
self.assertEqual(ref_names, act_names) # Should be in order
def test_loadbalancer_sorting_and_pagination(self):
# Python's stable sort will allow us to simulate the full sorting
# capabilities of the api during testing.
exp_order = self.random_name_desc[:]
exp_order.sort(key=operator.itemgetter(1), reverse=False)
exp_order.sort(key=operator.itemgetter(0), reverse=True)
for (name, desc) in self.random_name_desc:
self.create_load_balancer(
{'subnet_id': uuidutils.generate_uuid()},
name=name, description=desc)
params = {'sort': 'name:desc,description:asc',
'project_id': self.project_id}
# Get all lbs
resp = self.get(self.LBS_PATH, headers=self.headers, params=params)
all_lbs = json.loads(resp.body)
# Test the first 8 which is just limit=8
params.update({'limit': '8'})
resp = self.get(self.LBS_PATH, headers=self.headers, params=params)
lbs = json.loads(resp.body)
fnd_name_descs = [(lb['name'], lb['description']) for lb in lbs]
self.assertEqual(exp_order[0:8], fnd_name_descs)
# Test the slice at 8:24 which is marker=7 limit=16
params.update({'marker': all_lbs[7].get('id'), 'limit': '16'})
resp = self.get(self.LBS_PATH, headers=self.headers, params=params)
lbs = json.loads(resp.body)
fnd_name_descs = [(lb['name'], lb['description']) for lb in lbs]
self.assertEqual(exp_order[8:24], fnd_name_descs)
# Test the slice at 32:56 which is marker=31 limit=24
params.update({'marker': all_lbs[31].get('id'), 'limit': '24'})
resp = self.get(self.LBS_PATH, headers=self.headers, params=params)
lbs = json.loads(resp.body)
fnd_name_descs = [(lb['name'], lb['description']) for lb in lbs]
self.assertEqual(exp_order[32:56], fnd_name_descs)
# Test the last 8 entries which is slice 56:64 marker=55 limit=8
params.update({'marker': all_lbs[55].get('id'), 'limit': '8'})
resp = self.get(self.LBS_PATH, headers=self.headers, params=params)
lbs = json.loads(resp.body)
fnd_name_descs = [(lb['name'], lb['description']) for lb in lbs]
self.assertEqual(exp_order[56:64], fnd_name_descs)
# Test that we don't get an overflow or some other error if
# the number of entries is less then the limit.
# This should only return 4 entries
params.update({'marker': all_lbs[59].get('id'), 'limit': '8'})
resp = self.get(self.LBS_PATH, headers=self.headers, params=params)
lbs = json.loads(resp.body)
fnd_name_descs = [(lb['name'], lb['description']) for lb in lbs]
self.assertEqual(exp_order[60:64], fnd_name_descs)
def test_listeners_sorting_and_pagination(self):
# Create a loadbalancer and create 2 listeners on it
lb = self.create_load_balancer(
{'subnet_id': uuidutils.generate_uuid()}, name="single_lb")
lb_id = lb['id']
self.set_lb_status(lb_id)
exp_desc_names = self.random_name_desc[30:40]
exp_desc_names.sort(key=operator.itemgetter(0), reverse=True)
exp_desc_names.sort(key=operator.itemgetter(1), reverse=True)
port = 0
# We did some heavy testing already and the set_lb_status function
# is recursive and leads to n*(n-1) iterations during this test so
# we only test 10 entries
for (name, description) in self.random_name_desc[30:40]:
port += 1
opts = {"name": name, "description": description}
self.create_listener(lb_id, constants.PROTOCOL_HTTP, port, **opts)
# Set the lb to active but don't recurse the child objects as
# that will create a n*(n-1) operation in this loop
self.set_lb_status(lb_id)
url = self.LISTENERS_PATH.format(lb_id=lb_id)
params = {'sort': 'description:desc,name:desc',
'project_id': self.project_id}
# Get all listeners
resp = self.get(url, headers=self.headers, params=params)
all_listeners = json.loads(resp.body)
# Test the slice at 3:6
params.update({'marker': all_listeners[2].get('id'), 'limit': '3'})
resp = self.get(url, headers=self.headers, params=params)
listeners = json.loads(resp.body)
fnd_name_desc = [(l['name'], l['description']) for l in listeners]
self.assertEqual(exp_desc_names[3:6], fnd_name_desc)
# Test the slice at 1:8
params.update({'marker': all_listeners[0].get('id'), 'limit': '7'})
resp = self.get(url, headers=self.headers, params=params)
listeners = json.loads(resp.body)
fnd_name_desc = [(l['name'], l['description']) for l in listeners]
self.assertEqual(exp_desc_names[1:8], fnd_name_desc)
def test_members_sorting_and_pagination(self):
lb = self.create_load_balancer(
{'subnet_id': uuidutils.generate_uuid()}, name="single_lb")
lb_id = lb['id']
self.set_lb_status(lb_id)
li = self.create_listener(lb_id, constants.PROTOCOL_HTTP, 80)
li_id = li['id']
self.set_lb_status(lb_id)
p = self.create_pool(lb_id, li_id, constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN)
self.set_lb_status(lb_id)
pool_id = p['id']
exp_ip_weights = [('127.0.0.4', 3), ('127.0.0.5', 1), ('127.0.0.2', 5),
('127.0.0.1', 4), ('127.0.0.3', 2)]
for(ip, weight) in exp_ip_weights:
self.create_member(lb_id, pool_id, ip, 80, weight=weight)
self.set_lb_status(lb_id)
exp_ip_weights.sort(key=operator.itemgetter(1))
exp_ip_weights.sort(key=operator.itemgetter(0))
url = self.MEMBERS_PATH.format(lb_id=lb_id, pool_id=pool_id)
params = {'sort': 'ip_address,weight:asc',
'project_id': self.project_id}
# Get all members
resp = self.get(url, headers=self.headers, params=params)
all_members = json.loads(resp.body)
# These tests are getting exhaustive -- just test marker=0 limit=2
params.update({'marker': all_members[0].get('id'), 'limit': '2'})
resp = self.get(url, headers=self.headers, params=params)
members = json.loads(resp.body)
fnd_ip_subs = [(m['ip_address'], m['weight']) for m in members]
self.assertEqual(exp_ip_weights[1:3], fnd_ip_subs)
def test_invalid_limit(self):
params = {'project_id': self.project_id,
'limit': 'a'}
self.get(self.LBS_PATH, headers=self.headers, params=params,
status=400)
def test_invalid_marker(self):
params = {'project_id': self.project_id,
'marker': 'not_a_valid_uuid'}
self.get(self.LBS_PATH, headers=self.headers, params=params,
status=400)
def test_invalid_sort_key(self):
params = {'sort': 'name:desc:asc',
'project_id': self.project_id}
self.get(self.LBS_PATH, headers=self.headers, params=params,
status=400)

View File

@ -257,13 +257,28 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
response = self.post(path, body) response = self.post(path, body)
return response.json return response.json
def create_quota(self, project_id=-1, lb_quota=None, listener_quota=None,
pool_quota=None, hm_quota=None, member_quota=None):
if project_id == -1:
project_id = self.project_id
req_dict = {'load_balancer': lb_quota,
'listener': listener_quota,
'pool': pool_quota,
'health_monitor': hm_quota,
'member': member_quota}
req_dict = {k: v for k, v in req_dict.items() if v is not None}
body = {'quota': req_dict}
path = self.QUOTA_PATH.format(project_id=project_id)
response = self.put(path, body, status=202)
return response.json
def _set_lb_and_children_statuses(self, lb_id, prov_status, op_status, def _set_lb_and_children_statuses(self, lb_id, prov_status, op_status,
autodetect=True): autodetect=True):
self.set_object_status(self.lb_repo, lb_id, self.set_object_status(self.lb_repo, lb_id,
provisioning_status=prov_status, provisioning_status=prov_status,
operating_status=op_status) operating_status=op_status)
lb_listeners = self.listener_repo.get_all(db_api.get_session(), lb_listeners, _ = self.listener_repo.get_all(
load_balancer_id=lb_id) db_api.get_session(), load_balancer_id=lb_id)
for listener in lb_listeners: for listener in lb_listeners:
if autodetect and (listener.provisioning_status == if autodetect and (listener.provisioning_status ==
constants.PENDING_DELETE): constants.PENDING_DELETE):
@ -273,8 +288,8 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
self.set_object_status(self.listener_repo, listener.id, self.set_object_status(self.listener_repo, listener.id,
provisioning_status=listener_prov, provisioning_status=listener_prov,
operating_status=op_status) operating_status=op_status)
lb_l7policies = self.l7policy_repo.get_all(db_api.get_session(), lb_l7policies, _ = self.l7policy_repo.get_all(
listener_id=listener.id) db_api.get_session(), listener_id=listener.id)
for l7policy in lb_l7policies: for l7policy in lb_l7policies:
if autodetect and (l7policy.provisioning_status == if autodetect and (l7policy.provisioning_status ==
constants.PENDING_DELETE): constants.PENDING_DELETE):
@ -284,8 +299,8 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
self.set_object_status(self.l7policy_repo, l7policy.id, self.set_object_status(self.l7policy_repo, l7policy.id,
provisioning_status=l7policy_prov, provisioning_status=l7policy_prov,
operating_status=op_status) operating_status=op_status)
l7rules = self.l7rule_repo.get_all(db_api.get_session(), l7rules, _ = self.l7rule_repo.get_all(
l7policy_id=l7policy.id) db_api.get_session(), l7policy_id=l7policy.id)
for l7rule in l7rules: for l7rule in l7rules:
if autodetect and (l7rule.provisioning_status == if autodetect and (l7rule.provisioning_status ==
constants.PENDING_DELETE): constants.PENDING_DELETE):
@ -295,7 +310,7 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
self.set_object_status(self.l7rule_repo, l7rule.id, self.set_object_status(self.l7rule_repo, l7rule.id,
provisioning_status=l7rule_prov, provisioning_status=l7rule_prov,
operating_status=op_status) operating_status=op_status)
lb_pools = self.pool_repo.get_all(db_api.get_session(), lb_pools, _ = self.pool_repo.get_all(db_api.get_session(),
load_balancer_id=lb_id) load_balancer_id=lb_id)
for pool in lb_pools: for pool in lb_pools:
if autodetect and (pool.provisioning_status == if autodetect and (pool.provisioning_status ==

View File

@ -12,9 +12,12 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import mock
from oslo_utils import uuidutils from oslo_utils import uuidutils
from octavia.common import constants from octavia.common import constants
import octavia.common.context
from octavia.common import data_models from octavia.common import data_models
from octavia.tests.functional.api.v2 import base from octavia.tests.functional.api.v2 import base
@ -89,6 +92,240 @@ class TestHealthMonitor(base.BaseAPITest):
self.assertEqual(1, len(hms)) self.assertEqual(1, len(hms))
self.assertEqual(api_hm.get('id'), hms[0].get('id')) self.assertEqual(api_hm.get('id'), hms[0].get('id'))
def test_get_all_admin(self):
project_id = uuidutils.generate_uuid()
lb1 = self.create_load_balancer(uuidutils.generate_uuid(), name='lb1',
project_id=project_id)
lb1_id = lb1.get('loadbalancer').get('id')
self.set_lb_status(lb1_id)
pool1 = self.create_pool(
lb1_id, constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb1_id)
pool2 = self.create_pool(
lb1_id, constants.PROTOCOL_HTTPS,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb1_id)
pool3 = self.create_pool(
lb1_id, constants.PROTOCOL_TCP,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb1_id)
hm1 = self.create_health_monitor(
pool1.get('id'), constants.HEALTH_MONITOR_HTTP,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb1_id)
hm2 = self.create_health_monitor(
pool2.get('id'), constants.HEALTH_MONITOR_PING,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb1_id)
hm3 = self.create_health_monitor(
pool3.get('id'), constants.HEALTH_MONITOR_TCP,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb1_id)
hms = self.get(self.HMS_PATH).json.get(self.root_tag_list)
self.assertEqual(3, len(hms))
hm_id_protocols = [(hm.get('id'), hm.get('type')) for hm in hms]
self.assertIn((hm1.get('id'), hm1.get('type')), hm_id_protocols)
self.assertIn((hm2.get('id'), hm2.get('type')), hm_id_protocols)
self.assertIn((hm3.get('id'), hm3.get('type')), hm_id_protocols)
def test_get_all_non_admin(self):
project_id = uuidutils.generate_uuid()
lb1 = self.create_load_balancer(uuidutils.generate_uuid(), name='lb1',
project_id=project_id)
lb1_id = lb1.get('loadbalancer').get('id')
self.set_lb_status(lb1_id)
pool1 = self.create_pool(
lb1_id, constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb1_id)
pool2 = self.create_pool(
lb1_id, constants.PROTOCOL_HTTPS,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb1_id)
self.create_health_monitor(
pool1.get('id'), constants.HEALTH_MONITOR_HTTP,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb1_id)
self.create_health_monitor(
pool2.get('id'), constants.HEALTH_MONITOR_PING,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb1_id)
hm3 = self.create_health_monitor(
self.pool_id, constants.HEALTH_MONITOR_TCP,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(self.lb_id)
auth_strategy = self.conf.conf.get('auth_strategy')
self.conf.config(auth_strategy=constants.KEYSTONE)
with mock.patch.object(octavia.common.context.Context, 'project_id',
hm3['project_id']):
hms = self.get(self.HMS_PATH).json.get(self.root_tag_list)
self.conf.config(auth_strategy=auth_strategy)
self.assertEqual(1, len(hms))
hm_id_protocols = [(hm.get('id'), hm.get('type')) for hm in hms]
self.assertIn((hm3.get('id'), hm3.get('type')), hm_id_protocols)
def test_get_by_project_id(self):
project1_id = uuidutils.generate_uuid()
project2_id = uuidutils.generate_uuid()
lb1 = self.create_load_balancer(uuidutils.generate_uuid(), name='lb1',
project_id=project1_id)
lb1_id = lb1.get('loadbalancer').get('id')
self.set_lb_status(lb1_id)
lb2 = self.create_load_balancer(uuidutils.generate_uuid(), name='lb2',
project_id=project2_id)
lb2_id = lb2.get('loadbalancer').get('id')
self.set_lb_status(lb2_id)
pool1 = self.create_pool(
lb1_id, constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb1_id)
pool2 = self.create_pool(
lb1_id, constants.PROTOCOL_HTTPS,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb1_id)
pool3 = self.create_pool(
lb2_id, constants.PROTOCOL_TCP,
constants.LB_ALGORITHM_ROUND_ROBIN).get('pool')
self.set_lb_status(lb2_id)
hm1 = self.create_health_monitor(
pool1.get('id'), constants.HEALTH_MONITOR_HTTP,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb1_id)
hm2 = self.create_health_monitor(
pool2.get('id'), constants.HEALTH_MONITOR_PING,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb1_id)
hm3 = self.create_health_monitor(
pool3.get('id'), constants.HEALTH_MONITOR_TCP,
1, 1, 1, 1).get(self.root_tag)
self.set_lb_status(lb2_id)
hms = self.get(
self.HMS_PATH,
params={'project_id': project1_id}).json.get(self.root_tag_list)
self.assertEqual(2, len(hms))
hm_id_protocols = [(hm.get('id'), hm.get('type')) for hm in hms]
self.assertIn((hm1.get('id'), hm1.get('type')), hm_id_protocols)
self.assertIn((hm2.get('id'), hm2.get('type')), hm_id_protocols)
hms = self.get(
self.HMS_PATH,
params={'project_id': project2_id}).json.get(self.root_tag_list)
self.assertEqual(1, len(hms))
hm_id_protocols = [(hm.get('id'), hm.get('type')) for hm in hms]
self.assertIn((hm3.get('id'), hm3.get('type')), hm_id_protocols)
def test_get_all_sorted(self):
pool1 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1').get('pool')
self.set_lb_status(self.lb_id)
pool2 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2').get('pool')
self.set_lb_status(self.lb_id)
pool3 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3').get('pool')
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool1.get('id'), constants.HEALTH_MONITOR_HTTP,
1, 1, 1, 1, name='hm1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool2.get('id'), constants.HEALTH_MONITOR_PING,
1, 1, 1, 1, name='hm2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool3.get('id'), constants.HEALTH_MONITOR_TCP,
1, 1, 1, 1, name='hm3').get(self.root_tag)
self.set_lb_status(self.lb_id)
response = self.get(self.HMS_PATH, params={'sort': 'name:desc'})
hms_desc = response.json.get(self.root_tag_list)
response = self.get(self.HMS_PATH, params={'sort': 'name:asc'})
hms_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(hms_desc))
self.assertEqual(3, len(hms_asc))
hm_id_names_desc = [(hm.get('id'), hm.get('name')) for hm in hms_desc]
hm_id_names_asc = [(hm.get('id'), hm.get('name')) for hm in hms_asc]
self.assertEqual(hm_id_names_asc, list(reversed(hm_id_names_desc)))
def test_get_all_limited(self):
pool1 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1').get('pool')
self.set_lb_status(self.lb_id)
pool2 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2').get('pool')
self.set_lb_status(self.lb_id)
pool3 = self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3').get('pool')
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool1.get('id'), constants.HEALTH_MONITOR_HTTP,
1, 1, 1, 1, name='hm1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool2.get('id'), constants.HEALTH_MONITOR_PING,
1, 1, 1, 1, name='hm2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_health_monitor(
pool3.get('id'), constants.HEALTH_MONITOR_TCP,
1, 1, 1, 1, name='hm3').get(self.root_tag)
self.set_lb_status(self.lb_id)
# First two -- should have 'next' link
first_two = self.get(self.HMS_PATH, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.HMS_PATH, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.HMS_PATH, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_empty_get_all(self):
response = self.get(self.HMS_PATH).json.get(self.root_tag_list)
self.assertIsInstance(response, list)
self.assertEqual(0, len(response))
def test_create_http_monitor_with_relative_path(self): def test_create_http_monitor_with_relative_path(self):
api_hm = self.create_health_monitor( api_hm = self.create_health_monitor(
self.pool_id, constants.HEALTH_MONITOR_HTTP, self.pool_id, constants.HEALTH_MONITOR_HTTP,

View File

@ -209,6 +209,83 @@ class TestL7Policy(base.BaseAPITest):
self.assertIn((api_l7p_c.get('id'), api_l7p_c.get('action')), self.assertIn((api_l7p_c.get('id'), api_l7p_c.get('action')),
policy_id_actions) policy_id_actions)
def test_get_all_sorted(self):
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REJECT,
name='policy3').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_POOL,
position=2, redirect_pool_id=self.pool_id,
name='policy2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_URL,
redirect_url='http://localhost/',
name='policy1').get(self.root_tag)
self.set_lb_status(self.lb_id)
response = self.get(self.L7POLICIES_PATH,
params={'sort': 'position:desc'})
policies_desc = response.json.get(self.root_tag_list)
response = self.get(self.L7POLICIES_PATH,
params={'sort': 'position:asc'})
policies_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(policies_desc))
self.assertEqual(3, len(policies_asc))
policy_id_names_desc = [(policy.get('id'), policy.get('position'))
for policy in policies_desc]
policy_id_names_asc = [(policy.get('id'), policy.get('position'))
for policy in policies_asc]
self.assertEqual(policy_id_names_asc,
list(reversed(policy_id_names_desc)))
def test_get_all_limited(self):
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REJECT,
name='policy1').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_POOL,
position=2, redirect_pool_id=self.pool_id,
name='policy2').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7policy(
self.listener_id, constants.L7POLICY_ACTION_REDIRECT_TO_URL,
redirect_url='http://localhost/',
name='policy3').get(self.root_tag)
self.set_lb_status(self.lb_id)
# First two -- should have 'next' link
first_two = self.get(self.L7POLICIES_PATH, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.L7POLICIES_PATH, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.L7POLICIES_PATH, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_empty_get_all(self): def test_empty_get_all(self):
response = self.get(self.L7POLICIES_PATH).json.get(self.root_tag_list) response = self.get(self.L7POLICIES_PATH).json.get(self.root_tag_list)
self.assertIsInstance(response, list) self.assertIsInstance(response, list)

View File

@ -96,6 +96,85 @@ class TestL7Rule(base.BaseAPITest):
self.assertIn((api_l7r_b.get('id'), api_l7r_b.get('type')), self.assertIn((api_l7r_b.get('id'), api_l7r_b.get('type')),
rule_id_types) rule_id_types)
def test_get_all_sorted(self):
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_PATH,
constants.L7RULE_COMPARE_TYPE_STARTS_WITH,
'/api').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_COOKIE,
constants.L7RULE_COMPARE_TYPE_CONTAINS, 'some-value',
key='some-cookie').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_HOST_NAME,
constants.L7RULE_COMPARE_TYPE_EQUAL_TO,
'www.example.com').get(self.root_tag)
self.set_lb_status(self.lb_id)
response = self.get(self.l7rules_path,
params={'sort': 'type:desc'})
rules_desc = response.json.get(self.root_tag_list)
response = self.get(self.l7rules_path,
params={'sort': 'type:asc'})
rules_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(rules_desc))
self.assertEqual(3, len(rules_asc))
rule_id_types_desc = [(rule.get('id'), rule.get('type'))
for rule in rules_desc]
rule_id_types_asc = [(rule.get('id'), rule.get('type'))
for rule in rules_asc]
self.assertEqual(rule_id_types_asc,
list(reversed(rule_id_types_desc)))
def test_get_all_limited(self):
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_PATH,
constants.L7RULE_COMPARE_TYPE_STARTS_WITH,
'/api').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_COOKIE,
constants.L7RULE_COMPARE_TYPE_CONTAINS, 'some-value',
key='some-cookie').get(self.root_tag)
self.set_lb_status(self.lb_id)
self.create_l7rule(
self.l7policy_id, constants.L7RULE_TYPE_HOST_NAME,
constants.L7RULE_COMPARE_TYPE_EQUAL_TO,
'www.example.com').get(self.root_tag)
self.set_lb_status(self.lb_id)
# First two -- should have 'next' link
first_two = self.get(self.l7rules_path, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.l7rules_path, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.l7rules_path, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_empty_get_all(self): def test_empty_get_all(self):
response = self.get(self.l7rules_path).json.get(self.root_tag_list) response = self.get(self.l7rules_path).json.get(self.root_tag_list)
self.assertIsInstance(response, list) self.assertIsInstance(response, list)

View File

@ -27,6 +27,7 @@ class TestListener(base.BaseAPITest):
root_tag = 'listener' root_tag = 'listener'
root_tag_list = 'listeners' root_tag_list = 'listeners'
root_tag_links = 'listeners_links'
def setUp(self): def setUp(self):
super(TestListener, self).setUp() super(TestListener, self).setUp()
@ -137,6 +138,78 @@ class TestListener(base.BaseAPITest):
self.assertIn((listener3.get('id'), listener3.get('protocol_port')), self.assertIn((listener3.get('id'), listener3.get('protocol_port')),
listener_id_ports) listener_id_ports)
def test_get_all_sorted(self):
self.create_listener(constants.PROTOCOL_HTTP, 80,
self.lb_id,
name='listener1')
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP, 81,
self.lb_id,
name='listener2')
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP, 82,
self.lb_id,
name='listener3')
self.set_lb_status(self.lb_id)
response = self.get(self.LISTENERS_PATH,
params={'sort': 'name:desc'})
listeners_desc = response.json.get(self.root_tag_list)
response = self.get(self.LISTENERS_PATH,
params={'sort': 'name:asc'})
listeners_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(listeners_desc))
self.assertEqual(3, len(listeners_asc))
listener_id_names_desc = [(listener.get('id'), listener.get('name'))
for listener in listeners_desc]
listener_id_names_asc = [(listener.get('id'), listener.get('name'))
for listener in listeners_asc]
self.assertEqual(listener_id_names_asc,
list(reversed(listener_id_names_desc)))
def test_get_all_limited(self):
self.create_listener(constants.PROTOCOL_HTTP, 80,
self.lb_id,
name='listener1')
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP, 81,
self.lb_id,
name='listener2')
self.set_lb_status(self.lb_id)
self.create_listener(constants.PROTOCOL_HTTP, 82,
self.lb_id,
name='listener3')
self.set_lb_status(self.lb_id)
# First two -- should have 'next' link
first_two = self.get(self.LISTENERS_PATH, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.LISTENERS_PATH, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.LISTENERS_PATH, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get(self): def test_get(self):
listener = self.create_listener( listener = self.create_listener(
constants.PROTOCOL_HTTP, 80, self.lb_id).get(self.root_tag) constants.PROTOCOL_HTTP, 80, self.lb_id).get(self.root_tag)

View File

@ -28,6 +28,7 @@ from octavia.tests.functional.api.v2 import base
class TestLoadBalancer(base.BaseAPITest): class TestLoadBalancer(base.BaseAPITest):
root_tag = 'loadbalancer' root_tag = 'loadbalancer'
root_tag_list = 'loadbalancers' root_tag_list = 'loadbalancers'
root_tag_links = 'loadbalancers_links'
def _assert_request_matches_response(self, req, resp, **optionals): def _assert_request_matches_response(self, req, resp, **optionals):
self.assertTrue(uuidutils.is_uuid_like(resp.get('id'))) self.assertTrue(uuidutils.is_uuid_like(resp.get('id')))
@ -309,6 +310,69 @@ class TestLoadBalancer(base.BaseAPITest):
self.assertEqual(1, len(lbs)) self.assertEqual(1, len(lbs))
self.assertIn((lb3.get('id'), lb3.get('name')), lb_id_names) self.assertIn((lb3.get('id'), lb3.get('name')), lb_id_names)
def test_get_all_sorted(self):
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb1',
project_id=self.project_id)
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb2',
project_id=self.project_id)
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb3',
project_id=self.project_id)
response = self.get(self.LBS_PATH,
params={'sort': 'name:desc'})
lbs_desc = response.json.get(self.root_tag_list)
response = self.get(self.LBS_PATH,
params={'sort': 'name:asc'})
lbs_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(lbs_desc))
self.assertEqual(3, len(lbs_asc))
lb_id_names_desc = [(lb.get('id'), lb.get('name')) for lb in lbs_desc]
lb_id_names_asc = [(lb.get('id'), lb.get('name')) for lb in lbs_asc]
self.assertEqual(lb_id_names_asc, list(reversed(lb_id_names_desc)))
def test_get_all_limited(self):
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb1',
project_id=self.project_id)
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb2',
project_id=self.project_id)
self.create_load_balancer(uuidutils.generate_uuid(),
name='lb3',
project_id=self.project_id)
# First two -- should have 'next' link
first_two = self.get(self.LBS_PATH, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.LBS_PATH, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.LBS_PATH, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get(self): def test_get(self):
project_id = uuidutils.generate_uuid() project_id = uuidutils.generate_uuid()
subnet = network_models.Subnet(id=uuidutils.generate_uuid()) subnet = network_models.Subnet(id=uuidutils.generate_uuid())

View File

@ -103,6 +103,67 @@ class TestMember(base.BaseAPITest):
for m in [api_m_1, api_m_2]: for m in [api_m_1, api_m_2]:
self.assertIn(m, response) self.assertIn(m, response)
def test_get_all_sorted(self):
self.create_member(self.pool_id, '10.0.0.1', 80, name='member1')
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id, '10.0.0.2', 80, name='member2')
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id, '10.0.0.3', 80, name='member3')
self.set_lb_status(self.lb_id)
response = self.get(self.members_path,
params={'sort': 'name:desc'})
members_desc = response.json.get(self.root_tag_list)
response = self.get(self.members_path,
params={'sort': 'name:asc'})
members_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(members_desc))
self.assertEqual(3, len(members_asc))
member_id_names_desc = [(member.get('id'), member.get('name'))
for member in members_desc]
member_id_names_asc = [(member.get('id'), member.get('name'))
for member in members_asc]
self.assertEqual(member_id_names_asc,
list(reversed(member_id_names_desc)))
def test_get_all_limited(self):
self.create_member(self.pool_id, '10.0.0.1', 80, name='member1')
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id, '10.0.0.2', 80, name='member2')
self.set_lb_status(self.lb_id)
self.create_member(self.pool_id, '10.0.0.3', 80, name='member3')
self.set_lb_status(self.lb_id)
# First two -- should have 'next' link
first_two = self.get(self.members_path, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.members_path, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.members_path, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_empty_get_all(self): def test_empty_get_all(self):
response = self.get(self.members_path).json.get(self.root_tag_list) response = self.get(self.members_path).json.get(self.root_tag_list)
self.assertIsInstance(response, list) self.assertIsInstance(response, list)

View File

@ -206,6 +206,91 @@ class TestPool(base.BaseAPITest):
self.assertEqual(1, len(response)) self.assertEqual(1, len(response))
self.assertEqual(api_pool.get('id'), response[0].get('id')) self.assertEqual(api_pool.get('id'), response[0].get('id'))
def test_get_all_sorted(self):
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1')
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2')
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3')
self.set_lb_status(lb_id=self.lb_id)
response = self.get(self.POOLS_PATH,
params={'sort': 'name:desc'})
pools_desc = response.json.get(self.root_tag_list)
response = self.get(self.POOLS_PATH,
params={'sort': 'name:asc'})
pools_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(pools_desc))
self.assertEqual(3, len(pools_asc))
pool_id_names_desc = [(pool.get('id'), pool.get('name'))
for pool in pools_desc]
pool_id_names_asc = [(pool.get('id'), pool.get('name'))
for pool in pools_asc]
self.assertEqual(pool_id_names_asc,
list(reversed(pool_id_names_desc)))
def test_get_all_limited(self):
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool1')
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool2')
self.set_lb_status(lb_id=self.lb_id)
self.create_pool(
self.lb_id,
constants.PROTOCOL_HTTP,
constants.LB_ALGORITHM_ROUND_ROBIN,
name='pool3')
self.set_lb_status(lb_id=self.lb_id)
# First two -- should have 'next' link
first_two = self.get(self.POOLS_PATH, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.POOLS_PATH, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.POOLS_PATH, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_empty_get_all(self): def test_empty_get_all(self):
response = self.get(self.POOLS_PATH).json.get(self.root_tag_list) response = self.get(self.POOLS_PATH).json.get(self.root_tag_list)
self.assertIsInstance(response, list) self.assertIsInstance(response, list)

View File

@ -12,13 +12,15 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import mock
import random import random
from oslo_config import cfg from oslo_config import cfg
from oslo_config import fixture as oslo_fixture from oslo_config import fixture as oslo_fixture
from oslo_utils import uuidutils from oslo_utils import uuidutils
from octavia.common import constants as const from octavia.common import constants
import octavia.common.context
from octavia.tests.functional.api.v2 import base from octavia.tests.functional.api.v2 import base
CONF = cfg.CONF CONF = cfg.CONF
@ -26,26 +28,32 @@ CONF = cfg.CONF
class TestQuotas(base.BaseAPITest): class TestQuotas(base.BaseAPITest):
root_tag = 'quota'
root_tag_list = 'quotas'
root_tag_links = 'quotas_links'
def setUp(self): def setUp(self):
super(TestQuotas, self).setUp() super(TestQuotas, self).setUp()
conf = self.useFixture(oslo_fixture.Config(cfg.CONF)) conf = self.useFixture(oslo_fixture.Config(cfg.CONF))
conf.config( conf.config(
group="quotas", group="quotas",
default_load_balancer_quota=random.randrange(const.QUOTA_UNLIMITED, default_load_balancer_quota=random.randrange(
9000)) constants.QUOTA_UNLIMITED, 9000))
conf.config( conf.config(
group="quotas", group="quotas",
default_listener_quota=random.randrange(const.QUOTA_UNLIMITED, default_listener_quota=random.randrange(
9000)) constants.QUOTA_UNLIMITED, 9000))
conf.config( conf.config(
group="quotas", group="quotas",
default_member_quota=random.randrange(const.QUOTA_UNLIMITED, 9000)) default_member_quota=random.randrange(
constants.QUOTA_UNLIMITED, 9000))
# We need to make sure unlimited gets tested each pass # We need to make sure unlimited gets tested each pass
conf.config(group="quotas", default_pool_quota=const.QUOTA_UNLIMITED) conf.config(group="quotas",
default_pool_quota=constants.QUOTA_UNLIMITED)
conf.config( conf.config(
group="quotas", group="quotas",
default_health_monitor_quota=random.randrange( default_health_monitor_quota=random.randrange(
const.QUOTA_UNLIMITED, 9000)) constants.QUOTA_UNLIMITED, 9000))
self.project_id = uuidutils.generate_uuid() self.project_id = uuidutils.generate_uuid()
@ -68,13 +76,13 @@ class TestQuotas(base.BaseAPITest):
def test_get_all_quotas_no_quotas(self): def test_get_all_quotas_no_quotas(self):
response = self.get(self.QUOTAS_PATH) response = self.get(self.QUOTAS_PATH)
quota_list = response.json quota_list = response.json
self.assertEqual({'quotas': []}, quota_list) self.assertEqual({'quotas': [], 'quotas_links': []}, quota_list)
def test_get_all_quotas_with_quotas(self): def test_get_all_quotas_with_quotas(self):
project_id1 = uuidutils.generate_uuid() project_id1 = uuidutils.generate_uuid()
project_id2 = uuidutils.generate_uuid() project_id2 = uuidutils.generate_uuid()
quota_path1 = self.QUOTA_PATH.format(project_id=project_id1) quota_path1 = self.QUOTA_PATH.format(project_id=project_id1)
quota1 = {'load_balancer': const.QUOTA_UNLIMITED, 'listener': 30, quota1 = {'load_balancer': constants.QUOTA_UNLIMITED, 'listener': 30,
'pool': 30, 'health_monitor': 30, 'member': 30} 'pool': 30, 'health_monitor': 30, 'member': 30}
body1 = {'quota': quota1} body1 = {'quota': quota1}
self.put(quota_path1, body1, status=202) self.put(quota_path1, body1, status=202)
@ -89,9 +97,153 @@ class TestQuotas(base.BaseAPITest):
quota1['project_id'] = project_id1 quota1['project_id'] = project_id1
quota2['project_id'] = project_id2 quota2['project_id'] = project_id2
expected = {'quotas': [quota1, quota2]} expected = {'quotas': [quota1, quota2], 'quotas_links': []}
self.assertEqual(expected, quota_list) self.assertEqual(expected, quota_list)
def test_get_all_admin(self):
project_id1 = uuidutils.generate_uuid()
project_id2 = uuidutils.generate_uuid()
project_id3 = uuidutils.generate_uuid()
quota1 = self.create_quota(
project_id=project_id1, lb_quota=1, member_quota=1
).get(self.root_tag)
quota2 = self.create_quota(
project_id=project_id2, lb_quota=2, member_quota=2
).get(self.root_tag)
quota3 = self.create_quota(
project_id=project_id3, lb_quota=3, member_quota=3
).get(self.root_tag)
quotas = self.get(self.QUOTAS_PATH).json.get(self.root_tag_list)
self.assertEqual(3, len(quotas))
quota_lb_member_quotas = [(l.get('load_balancer'), l.get('member'))
for l in quotas]
self.assertIn((quota1.get('load_balancer'), quota1.get('member')),
quota_lb_member_quotas)
self.assertIn((quota2.get('load_balancer'), quota2.get('member')),
quota_lb_member_quotas)
self.assertIn((quota3.get('load_balancer'), quota3.get('member')),
quota_lb_member_quotas)
def test_get_all_non_admin(self):
project1_id = uuidutils.generate_uuid()
project2_id = uuidutils.generate_uuid()
project3_id = uuidutils.generate_uuid()
self.create_quota(
project_id=project1_id, lb_quota=1, member_quota=1
).get(self.root_tag)
self.create_quota(
project_id=project2_id, lb_quota=2, member_quota=2
).get(self.root_tag)
quota3 = self.create_quota(
project_id=project3_id, lb_quota=3, member_quota=3
).get(self.root_tag)
auth_strategy = self.conf.conf.get('auth_strategy')
self.conf.config(auth_strategy=constants.KEYSTONE)
with mock.patch.object(octavia.common.context.Context, 'project_id',
project3_id):
quotas = self.get(self.QUOTAS_PATH).json.get(self.root_tag_list)
self.conf.config(auth_strategy=auth_strategy)
self.assertEqual(1, len(quotas))
quota_lb_member_quotas = [(l.get('load_balancer'), l.get('member'))
for l in quotas]
self.assertIn((quota3.get('load_balancer'), quota3.get('member')),
quota_lb_member_quotas)
def test_get_by_project_id(self):
project1_id = uuidutils.generate_uuid()
project2_id = uuidutils.generate_uuid()
quota1 = self.create_quota(
project_id=project1_id, lb_quota=1, member_quota=1
).get(self.root_tag)
quota2 = self.create_quota(
project_id=project2_id, lb_quota=2, member_quota=2
).get(self.root_tag)
quotas = self.get(
self.QUOTA_PATH.format(project_id=project1_id)
).json.get(self.root_tag)
self._assert_quotas_equal(quotas, quota1)
quotas = self.get(
self.QUOTA_PATH.format(project_id=project2_id)
).json.get(self.root_tag)
self._assert_quotas_equal(quotas, quota2)
def test_get_all_sorted(self):
project1_id = uuidutils.generate_uuid()
project2_id = uuidutils.generate_uuid()
project3_id = uuidutils.generate_uuid()
self.create_quota(
project_id=project1_id, lb_quota=3, member_quota=8
).get(self.root_tag)
self.create_quota(
project_id=project2_id, lb_quota=2, member_quota=10
).get(self.root_tag)
self.create_quota(
project_id=project3_id, lb_quota=1, member_quota=9
).get(self.root_tag)
response = self.get(self.QUOTAS_PATH,
params={'sort': 'load_balancer:desc'})
quotas_desc = response.json.get(self.root_tag_list)
response = self.get(self.QUOTAS_PATH,
params={'sort': 'load_balancer:asc'})
quotas_asc = response.json.get(self.root_tag_list)
self.assertEqual(3, len(quotas_desc))
self.assertEqual(3, len(quotas_asc))
quota_lb_member_desc = [(l.get('load_balancer'), l.get('member'))
for l in quotas_desc]
quota_lb_member_asc = [(l.get('load_balancer'), l.get('member'))
for l in quotas_asc]
self.assertEqual(quota_lb_member_asc,
list(reversed(quota_lb_member_desc)))
def test_get_all_limited(self):
self.skipTest("No idea how this should work yet")
# TODO(rm_work): Figure out how to make this ... work
project1_id = uuidutils.generate_uuid()
project2_id = uuidutils.generate_uuid()
project3_id = uuidutils.generate_uuid()
self.create_quota(
project_id=project1_id, lb_quota=3, member_quota=8
).get(self.root_tag)
self.create_quota(
project_id=project2_id, lb_quota=2, member_quota=10
).get(self.root_tag)
self.create_quota(
project_id=project3_id, lb_quota=1, member_quota=9
).get(self.root_tag)
# First two -- should have 'next' link
first_two = self.get(self.QUOTAS_PATH, params={'limit': 2}).json
objs = first_two[self.root_tag_list]
links = first_two[self.root_tag_links]
self.assertEqual(2, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('next', links[0]['rel'])
# Third + off the end -- should have previous link
third = self.get(self.QUOTAS_PATH, params={
'limit': 2,
'marker': first_two[self.root_tag_list][1]['id']}).json
objs = third[self.root_tag_list]
links = third[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(1, len(links))
self.assertEqual('previous', links[0]['rel'])
# Middle -- should have both links
middle = self.get(self.QUOTAS_PATH, params={
'limit': 1,
'marker': first_two[self.root_tag_list][0]['id']}).json
objs = middle[self.root_tag_list]
links = middle[self.root_tag_links]
self.assertEqual(1, len(objs))
self.assertEqual(2, len(links))
self.assertItemsEqual(['previous', 'next'], [l['rel'] for l in links])
def test_get_default_quotas(self): def test_get_default_quotas(self):
response = self.get(self.QUOTA_DEFAULT_PATH.format( response = self.get(self.QUOTA_DEFAULT_PATH.format(
project_id=self.project_id)) project_id=self.project_id))

View File

@ -61,16 +61,16 @@ class BaseRepositoryTest(base.OctaviaDBTestBase):
self.quota_repo = repo.QuotasRepository() self.quota_repo = repo.QuotasRepository()
def test_get_all_return_value(self): def test_get_all_return_value(self):
pool_list = self.pool_repo.get_all(self.session, pool_list, _ = self.pool_repo.get_all(self.session,
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
self.assertIsInstance(pool_list, list) self.assertIsInstance(pool_list, list)
lb_list = self.lb_repo.get_all(self.session, lb_list, _ = self.lb_repo.get_all(self.session,
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
self.assertIsInstance(lb_list, list) self.assertIsInstance(lb_list, list)
listener_list = self.listener_repo.get_all(self.session, listener_list, _ = self.listener_repo.get_all(
project_id=self.FAKE_UUID_2) self.session, project_id=self.FAKE_UUID_2)
self.assertIsInstance(listener_list, list) self.assertIsInstance(listener_list, list)
member_list = self.member_repo.get_all(self.session, member_list, _ = self.member_repo.get_all(self.session,
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
self.assertIsInstance(member_list, list) self.assertIsInstance(member_list, list)
@ -1845,7 +1845,7 @@ class PoolRepositoryTest(BaseRepositoryTest):
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
pool_two = self.create_pool(pool_id=self.FAKE_UUID_3, pool_two = self.create_pool(pool_id=self.FAKE_UUID_3,
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
pool_list = self.pool_repo.get_all(self.session, pool_list, _ = self.pool_repo.get_all(self.session,
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
self.assertIsInstance(pool_list, list) self.assertIsInstance(pool_list, list)
self.assertEqual(2, len(pool_list)) self.assertEqual(2, len(pool_list))
@ -1994,7 +1994,7 @@ class MemberRepositoryTest(BaseRepositoryTest):
self.pool.id, "10.0.0.1") self.pool.id, "10.0.0.1")
member_two = self.create_member(self.FAKE_UUID_3, self.FAKE_UUID_2, member_two = self.create_member(self.FAKE_UUID_3, self.FAKE_UUID_2,
self.pool.id, "10.0.0.2") self.pool.id, "10.0.0.2")
member_list = self.member_repo.get_all(self.session, member_list, _ = self.member_repo.get_all(self.session,
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
self.assertIsInstance(member_list, list) self.assertIsInstance(member_list, list)
self.assertEqual(2, len(member_list)) self.assertEqual(2, len(member_list))
@ -2131,8 +2131,8 @@ class TestListenerRepositoryTest(BaseRepositoryTest):
def test_get_all(self): def test_get_all(self):
listener_one = self.create_listener(self.FAKE_UUID_1, 80) listener_one = self.create_listener(self.FAKE_UUID_1, 80)
listener_two = self.create_listener(self.FAKE_UUID_3, 88) listener_two = self.create_listener(self.FAKE_UUID_3, 88)
listener_list = self.listener_repo.get_all(self.session, listener_list, _ = self.listener_repo.get_all(
project_id=self.FAKE_UUID_2) self.session, project_id=self.FAKE_UUID_2)
self.assertIsInstance(listener_list, list) self.assertIsInstance(listener_list, list)
self.assertEqual(2, len(listener_list)) self.assertEqual(2, len(listener_list))
self.assertEqual(listener_one, listener_list[0]) self.assertEqual(listener_one, listener_list[0])
@ -2530,7 +2530,7 @@ class LoadBalancerRepositoryTest(BaseRepositoryTest):
def test_get_all(self): def test_get_all(self):
lb_one = self.create_loadbalancer(self.FAKE_UUID_1) lb_one = self.create_loadbalancer(self.FAKE_UUID_1)
lb_two = self.create_loadbalancer(self.FAKE_UUID_3) lb_two = self.create_loadbalancer(self.FAKE_UUID_3)
lb_list = self.lb_repo.get_all(self.session, lb_list, _ = self.lb_repo.get_all(self.session,
project_id=self.FAKE_UUID_2) project_id=self.FAKE_UUID_2)
self.assertEqual(2, len(lb_list)) self.assertEqual(2, len(lb_list))
self.assertEqual(lb_one, lb_list[0]) self.assertEqual(lb_one, lb_list[0])
@ -3234,7 +3234,7 @@ class L7PolicyRepositoryTest(BaseRepositoryTest):
self.assertEqual(1, new_l7policy_a.position) self.assertEqual(1, new_l7policy_a.position)
self.assertEqual(2, new_l7policy_b.position) self.assertEqual(2, new_l7policy_b.position)
self.assertEqual(3, new_l7policy_c.position) self.assertEqual(3, new_l7policy_c.position)
l7policy_list = self.l7policy_repo.get_all( l7policy_list, _ = self.l7policy_repo.get_all(
self.session, listener_id=listener.id) self.session, listener_id=listener.id)
self.assertIsInstance(l7policy_list, list) self.assertIsInstance(l7policy_list, list)
self.assertEqual(3, len(l7policy_list)) self.assertEqual(3, len(l7policy_list))
@ -3385,7 +3385,7 @@ class L7PolicyRepositoryTest(BaseRepositoryTest):
self.assertEqual(2, new_l7policy_b.position) self.assertEqual(2, new_l7policy_b.position)
self.assertEqual(3, new_l7policy_c.position) self.assertEqual(3, new_l7policy_c.position)
self.l7policy_repo.delete(self.session, id=l7policy_b.id) self.l7policy_repo.delete(self.session, id=l7policy_b.id)
l7policy_list = self.l7policy_repo.get_all( l7policy_list, _ = self.l7policy_repo.get_all(
self.session, listener_id=listener.id) self.session, listener_id=listener.id)
self.assertIsInstance(l7policy_list, list) self.assertIsInstance(l7policy_list, list)
self.assertEqual(2, len(l7policy_list)) self.assertEqual(2, len(l7policy_list))
@ -3601,7 +3601,7 @@ class L7RuleRepositoryTest(BaseRepositoryTest):
id=l7rule_a.id) id=l7rule_a.id)
new_l7rule_b = self.l7rule_repo.get(self.session, new_l7rule_b = self.l7rule_repo.get(self.session,
id=l7rule_b.id) id=l7rule_b.id)
l7rule_list = self.l7rule_repo.get_all( l7rule_list, _ = self.l7rule_repo.get_all(
self.session, l7policy_id=l7policy.id) self.session, l7policy_id=l7policy.id)
self.assertIsInstance(l7rule_list, list) self.assertIsInstance(l7rule_list, list)
self.assertEqual(2, len(l7rule_list)) self.assertEqual(2, len(l7rule_list))

View File

View File

@ -0,0 +1,171 @@
# Copyright 2016 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from oslo_config import cfg
from oslo_config import fixture as oslo_fixture
from oslo_utils import uuidutils
from octavia.api.common import pagination
from octavia.common import exceptions
from octavia.db import models
from octavia.tests.unit import base
DEFAULT_SORTS = [('created_at', 'asc'), ('id', 'asc')]
class TestPaginationHelper(base.TestCase):
@mock.patch('octavia.api.common.pagination.request')
def test_no_params(self, request_mock):
params = {}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer)
self.assertEqual(DEFAULT_SORTS, helper.sort_keys)
self.assertIsNone(helper.marker)
self.assertEqual(1000, helper.limit)
query_mock.order_by().order_by().limit.assert_called_with(1000)
def test_sort_empty(self):
sort_params = ""
params = {'sort': sort_params}
act_params = pagination.PaginationHelper(
params).sort_keys
self.assertEqual([], act_params)
def test_sort_none(self):
sort_params = None
params = {'sort': sort_params}
act_params = pagination.PaginationHelper(
params).sort_keys
self.assertEqual([], act_params)
def test_sort_key_dir(self):
sort_keys = "key1,key2,key3"
sort_dirs = "asc,desc"
ref_sort_keys = [('key1', 'asc'), ('key2', 'desc'), ('key3', 'asc')]
params = {'sort_key': sort_keys, 'sort_dir': sort_dirs}
helper = pagination.PaginationHelper(params)
self.assertEqual(ref_sort_keys, helper.sort_keys)
def test_invalid_sorts(self):
sort_params = "shoud_fail_exception:cause:of:this"
params = {'sort': sort_params}
self.assertRaises(exceptions.InvalidSortKey,
pagination.PaginationHelper,
params)
sort_params = "ke1:asc,key2:InvalidDir,key3"
params = {'sort': sort_params}
self.assertRaises(exceptions.InvalidSortDirection,
pagination.PaginationHelper,
params)
def test_marker(self):
marker = 'random_uuid'
params = {'marker': marker}
helper = pagination.PaginationHelper(params)
self.assertEqual(marker, helper.marker)
@mock.patch('octavia.api.common.pagination.request')
def test_limit(self, request_mock):
limit = 100
params = {'limit': limit}
helper = pagination.PaginationHelper(params)
query_mock = mock.MagicMock()
helper.apply(query_mock, models.LoadBalancer)
query_mock.order_by().order_by().limit.assert_called_with(limit)
@mock.patch('octavia.api.common.pagination.request')
def test_make_links_next(self, request_mock):
request_mock.path = "/lbaas/v2.0/pools/1/members"
request_mock.path_url = "http://localhost" + request_mock.path
member1 = models.Member()
member1.id = uuidutils.generate_uuid()
model_list = [member1]
params = {'limit': 1}
helper = pagination.PaginationHelper(params)
links = helper._make_links(model_list)
self.assertEqual(links[0].rel, "next")
self.assertEqual(
links[0].href,
"{path_url}?limit={limit}&marker={marker}".format(
path_url=request_mock.path_url,
limit=params['limit'],
marker=member1.id
))
@mock.patch('octavia.api.common.pagination.request')
def test_make_links_prev(self, request_mock):
request_mock.path = "/lbaas/v2.0/pools/1/members"
request_mock.path_url = "http://localhost" + request_mock.path
member1 = models.Member()
member1.id = uuidutils.generate_uuid()
model_list = [member1]
params = {'limit': 1, 'marker': member1.id}
helper = pagination.PaginationHelper(params)
links = helper._make_links(model_list)
self.assertEqual(links[0].rel, "previous")
self.assertEqual(
links[1].href,
"{path_url}?limit={limit}&marker={marker}".format(
path_url=request_mock.path_url,
limit=params['limit'],
marker=member1.id))
self.assertEqual(links[1].rel, "next")
self.assertEqual(
links[1].href,
"{path_url}?limit={limit}&marker={marker}".format(
path_url=request_mock.path_url,
limit=params['limit'],
marker=member1.id))
@mock.patch('octavia.api.common.pagination.request')
def test_make_links_with_configured_url(self, request_mock):
request_mock.path = "/lbaas/v2.0/pools/1/members"
request_mock.path_url = "http://localhost" + request_mock.path
api_base_uri = "https://127.0.0.1"
conf = self.useFixture(oslo_fixture.Config(cfg.CONF))
conf.config(api_base_uri=api_base_uri)
member1 = models.Member()
member1.id = uuidutils.generate_uuid()
model_list = [member1]
params = {'limit': 1, 'marker': member1.id}
helper = pagination.PaginationHelper(params)
links = helper._make_links(model_list)
self.assertEqual(links[0].rel, "previous")
self.assertEqual(
links[1].href,
"{base_uri}{path}?limit={limit}&marker={marker}".format(
base_uri=api_base_uri,
path=request_mock.path,
limit=params['limit'],
marker=member1.id
))
self.assertEqual(links[1].rel, "next")
self.assertEqual(
links[1].href,
"{base_uri}{path}?limit={limit}&marker={marker}".format(
base_uri=api_base_uri,
path=request_mock.path,
limit=params['limit'],
marker=member1.id))

View File

@ -45,7 +45,8 @@ class TestStatsMixin(base.TestCase):
total_connections=random.randrange(1000000000), total_connections=random.randrange(1000000000),
request_errors=random.randrange(1000000000)) request_errors=random.randrange(1000000000))
self.sm.listener_stats_repo.get_all.return_value = [self.fake_stats] self.sm.listener_stats_repo.get_all.return_value = ([self.fake_stats],
None)
self.repo_amphora = mock.MagicMock() self.repo_amphora = mock.MagicMock()
self.sm.repo_amphora = self.repo_amphora self.sm.repo_amphora = self.repo_amphora

View File

@ -108,7 +108,7 @@ class TestDatabaseCleanup(base.TestCase):
lb_network_ip=self.FAKE_IP, lb_network_ip=self.FAKE_IP,
vrrp_ip=self.FAKE_IP, vrrp_ip=self.FAKE_IP,
ha_ip=self.FAKE_IP) ha_ip=self.FAKE_IP)
self.amp_repo.get_all.return_value = [amphora] self.amp_repo.get_all.return_value = ([amphora], None)
self.amp_health_repo.check_amphora_expired.return_value = True self.amp_health_repo.check_amphora_expired.return_value = True
self.dbclean.delete_old_amphorae() self.dbclean.delete_old_amphorae()
self.assertTrue(self.amp_repo.get_all.called) self.assertTrue(self.amp_repo.get_all.called)
@ -126,7 +126,7 @@ class TestDatabaseCleanup(base.TestCase):
lb_network_ip=self.FAKE_IP, lb_network_ip=self.FAKE_IP,
vrrp_ip=self.FAKE_IP, vrrp_ip=self.FAKE_IP,
ha_ip=self.FAKE_IP) ha_ip=self.FAKE_IP)
self.amp_repo.get_all.return_value = [amphora] self.amp_repo.get_all.return_value = ([amphora], None)
self.amp_health_repo.check_amphora_expired.return_value = False self.amp_health_repo.check_amphora_expired.return_value = False
self.dbclean.delete_old_amphorae() self.dbclean.delete_old_amphorae()
self.assertTrue(self.amp_repo.get_all.called) self.assertTrue(self.amp_repo.get_all.called)
@ -146,7 +146,7 @@ class TestDatabaseCleanup(base.TestCase):
for expired_status in [True, False]: for expired_status in [True, False]:
lb_repo = mock.MagicMock() lb_repo = mock.MagicMock()
self.dbclean.lb_repo = lb_repo self.dbclean.lb_repo = lb_repo
lb_repo.get_all.return_value = [load_balancer] lb_repo.get_all.return_value = ([load_balancer], None)
lb_repo.check_load_balancer_expired.return_value = ( lb_repo.check_load_balancer_expired.return_value = (
expired_status) expired_status)
self.dbclean.cleanup_load_balancers() self.dbclean.cleanup_load_balancers()

View File

@ -631,7 +631,7 @@ class TestControllerWorker(base.TestCase):
'LoadBalancerFlows.get_update_load_balancer_flow', 'LoadBalancerFlows.get_update_load_balancer_flow',
return_value=_flow_mock) return_value=_flow_mock)
@mock.patch('octavia.db.repositories.ListenerRepository.get_all', @mock.patch('octavia.db.repositories.ListenerRepository.get_all',
return_value=[_listener_mock]) return_value=([_listener_mock], None))
def test_update_load_balancer(self, def test_update_load_balancer(self,
mock_listener_repo_get_all, mock_listener_repo_get_all,
mock_get_update_lb_flow, mock_get_update_lb_flow,