57f5666f71
Correct spelling 'seperated' to 'separated' Change-Id: I18fb80e803e21c01c51d49df9bcb0aae11b91a92
403 lines
17 KiB
Python
403 lines
17 KiB
Python
# Copyright 2016 Intel Corporation
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import copy
|
|
import itertools
|
|
|
|
from oslo_log import log as logging
|
|
from pecan import request
|
|
import sqlalchemy
|
|
from sqlalchemy.orm import aliased
|
|
import sqlalchemy.sql as sa_sql
|
|
|
|
from octavia.api.common import types
|
|
from octavia.common.config import cfg
|
|
from octavia.common import constants
|
|
from octavia.common import exceptions
|
|
from octavia.db import base_models
|
|
from octavia.db import models
|
|
|
|
# Module logger, named after this module.
LOG = logging.getLogger(__name__)
# Global configuration handle; the pagination code below reads its
# settings from CONF.api_settings.
CONF = cfg.CONF
|
|
|
|
|
|
class PaginationHelper(object):
    """Class helping to interact with pagination functionality

    Pass this class to `db.repositories` to apply it on query
    """

    # Request parameters that control paging/sorting/field selection and
    # therefore must never be treated as data filters in apply().
    _auxiliary_arguments = ('limit', 'marker',
                            'sort', 'sort_key', 'sort_dir',
                            'fields', 'page_reverse',
                            )

    def __init__(self, params, sort_dir=constants.DEFAULT_SORT_DIR):
        """Pagination Helper takes params and a default sort direction

        :param params: Contains the following:
                       limit: maximum number of items to return
                       marker: the id of the last item of the previous page;
                               we return the next results after this value.
                       sort: comma separated "key[:dir]" entries by which
                             results should be sorted
                       sort_key / sort_dir: comma separated keys and
                             directions, used only when ``sort`` is absent
                       page_reverse: the string "True" to fetch the page
                             that precedes the marker
                       Remaining entries are used as filters by apply()
                       when filtering is enabled.
        :param sort_dir: default direction to sort (asc, desc)
        :raises InvalidSortDirection: if a sort direction is not allowed.
        :raises InvalidSortKey: if a ``sort`` entry is malformed.
        :raises InvalidLimit: if ``limit`` is not an integer.
        """
        self.marker = params.get('marker')
        # Validated before _parse_sort_keys(), which reads self.sort_dir.
        self.sort_dir = self._validate_sort_dir(sort_dir)
        self.limit = self._parse_limit(params)
        self.sort_keys = self._parse_sort_keys(params)
        self.params = params
        self.filters = None
        self.page_reverse = params.get('page_reverse', 'False')

    @staticmethod
    def _parse_limit(params):
        """Determine the effective page size for a request.

        :param params: request parameters; only ``limit`` is read.
        :returns: an int page size capped by the configured
            ``pagination_max_limit``, or None meaning "no limit".
        :raises InvalidLimit: if ``limit`` cannot be converted to an int.
        """
        if CONF.api_settings.pagination_max_limit == 'infinite':
            page_max_limit = None
        else:
            page_max_limit = int(CONF.api_settings.pagination_max_limit)

        limit = params.get('limit')
        if limit is None:
            # No limit requested: fall back to the configured maximum.
            limit = page_max_limit
        if limit is None:
            # Neither the request nor the configuration bounds the page;
            # int(None) below would otherwise raise TypeError.
            return None
        try:
            # Deal with limit being a string or int meaning 'Unlimited'
            if limit == 'infinite' or int(limit) < 1:
                limit = None
            # If we don't have a max, just use whatever limit is specified
            elif page_max_limit is None:
                limit = int(limit)
            # Otherwise, we need to compare against the max
            else:
                limit = min(int(limit), page_max_limit)
        except (TypeError, ValueError) as e:
            # TypeError covers non-scalar values (e.g. a list).
            raise exceptions.InvalidLimit(key=limit) from e
        return limit

    def _parse_sort_keys(self, params):
        """Build the list of (sort_key, sort_dir) tuples for a request.

        ``sort`` ("key[:dir],key[:dir],...") takes precedence; otherwise the
        parallel comma separated ``sort_key`` / ``sort_dir`` lists are used.
        Missing directions default to ``self.sort_dir``.

        :raises InvalidSortKey: if a ``sort`` entry has more than one ':'.
        :raises InvalidSortDirection: if a direction is not allowed.
        """
        sort_keys_dirs = []
        sort = params.get('sort')
        sort_keys = params.get('sort_key')
        if sort:
            for sort_dir_key in sort.split(","):
                comps = sort_dir_key.split(":")
                if len(comps) == 1:  # Use default sort order
                    sort_keys_dirs.append((comps[0], self.sort_dir))
                elif len(comps) == 2:
                    sort_keys_dirs.append(
                        (comps[0], self._validate_sort_dir(comps[1])))
                else:
                    raise exceptions.InvalidSortKey(key=comps)
        elif sort_keys:
            sort_keys = sort_keys.split(',')
            sort_dirs = params.get('sort_dir')
            if not sort_dirs:
                sort_dirs = [self.sort_dir] * len(sort_keys)
            else:
                sort_dirs = sort_dirs.split(',')

            # Pad missing directions with the default; extra directions are
            # ignored by zip().
            if len(sort_dirs) < len(sort_keys):
                sort_dirs += [self.sort_dir] * (len(sort_keys) -
                                                len(sort_dirs))
            for sk, sd in zip(sort_keys, sort_dirs):
                sort_keys_dirs.append((sk, self._validate_sort_dir(sd)))

        return sort_keys_dirs

    def _parse_marker(self, session, model):
        """Return the ``model`` row whose id equals the marker, or None."""
        return session.query(model).filter_by(id=self.marker).one_or_none()

    @staticmethod
    def _get_default_column_value(column_type):
        """Return the default value of the columns from DB table

        The value stands in for NULL in the pagination comparisons.
        In postgreDB case, if no right default values are being set, an
        psycopg2.DataError will be thrown.

        :param column_type: the SQLAlchemy type of the sort column;
            ``Variant`` types are unwrapped to their base implementation.
        :raises KeyError: for column types without a known default.
        """
        type_schema = {
            'datetime': None,
            'big_integer': 0,
            'integer': 0,
            'string': '',
            'text': '',
            'boolean': False,
        }

        if isinstance(column_type, sa_sql.type_api.Variant):
            return PaginationHelper._get_default_column_value(column_type.impl)

        return type_schema[column_type.__visit_name__]

    @staticmethod
    def _validate_sort_dir(sort_dir):
        """Lower-case ``sort_dir`` and check it is an allowed direction.

        :raises InvalidSortDirection: if it is not in ALLOWED_SORT_DIR.
        """
        sort_dir = sort_dir.lower()
        if sort_dir not in constants.ALLOWED_SORT_DIR:
            raise exceptions.InvalidSortDirection(key=sort_dir)
        return sort_dir

    def _make_links(self, model_list):
        """Build the "previous"/"next" page links for ``model_list``.

        A "previous" link is emitted only when the request had a marker; a
        "next" link only when a page limit exists and the page is full.

        :returns: list of types.PageType (empty when model_list is empty).
        """
        if CONF.api_settings.api_base_uri:
            path_url = "{api_base_url}{path}".format(
                api_base_url=CONF.api_settings.api_base_uri.rstrip('/'),
                path=request.path)
        else:
            path_url = request.path_url
        links = []
        if model_list:
            # _parse_limit() maps 'infinite' back to "no limit"; formatting
            # None would yield "limit=None", which it rejects.
            link_limit = 'infinite' if self.limit is None else self.limit
            prev_attr = ["limit={}".format(link_limit)]
            if self.params.get('sort'):
                prev_attr.append("sort={}".format(self.params.get('sort')))
            if self.params.get('sort_key'):
                prev_attr.append("sort_key={}".format(
                    self.params.get('sort_key')))
            next_attr = copy.copy(prev_attr)
            if self.marker:
                prev_attr.append("marker={}".format(model_list[0].get('id')))
                prev_attr.append("page_reverse=True")
                prev_link = {
                    "rel": "previous",
                    "href": "{url}?{params}".format(
                        url=path_url,
                        params="&".join(prev_attr))
                }
                links.append(prev_link)
            # TODO(rm_work) Do we need to know when there are more vs exact?
            # We safely know if we have a full page, but it might include the
            # last element or it might not, it is unclear
            # Without a limit everything was returned, so there is no next
            # page (and comparing an int with None would raise TypeError).
            if self.limit is not None and len(model_list) >= self.limit:
                next_attr.append("marker={}".format(model_list[-1].get('id')))
                next_link = {
                    "rel": "next",
                    "href": "{url}?{params}".format(
                        url=path_url,
                        params="&".join(next_attr))
                }
                links.append(next_link)
        links = [types.PageType(**link) for link in links]
        return links

    def _apply_tags_filtering(self, params, model, query):
        """Apply the tag filters in ``params`` to ``query``.

        Models without a ``_tags`` relationship are returned unchanged.
        Handled keys (each value is a list of tags, consumed from params):
          tags:         object has every listed tag
          tags-any:     object has at least one listed tag
          not-tags:     object does not have all listed tags
          not-tags-any: object has none of the listed tags
        """
        if not getattr(model, "_tags", None):
            return query

        if 'tags' in params:
            tags = params.pop('tags')

            for tag in tags:
                # This requires a multi-join to the tags table,
                # so me must use aliases for each one.
                tag_alias = aliased(base_models.Tags)
                query = query.join(tag_alias, model._tags)
                query = query.filter(tag_alias.tag == tag)

        if 'tags-any' in params:
            tags = params.pop('tags-any')
            # An EXISTS test instead of a join: a join would return one row
            # per matching tag, and LIMIT would count those duplicates.
            query = query.filter(
                model._tags.any(base_models.Tags.tag.in_(tags)))

        if 'not-tags' in params:
            tags = params.pop('not-tags')
            # Select the ids that carry every listed tag, then exclude them.
            subq = query.session.query(model.id)
            for tag in tags:
                tag_alias = aliased(base_models.Tags)
                subq = subq.join(tag_alias, model._tags)
                subq = subq.filter(tag_alias.tag == tag)

            query = query.filter(~model.id.in_(subq))

        if 'not-tags-any' in params:
            tags = params.pop('not-tags-any')
            query = query.filter(
                ~model._tags.any(base_models.Tags.tag.in_(tags)))

        return query

    @staticmethod
    def _prepare_tags_list(param):
        """Split comma separated tags and return a flat list of tags."""
        if not isinstance(param, list):
            param = [param]
        return list(itertools.chain.from_iterable(
            tag.split(',') for tag in param))

    def apply(self, query, model, enforce_valid_params=True):
        """Returns the fetched rows with sorting / pagination criteria added.

        Pagination works by requiring a unique sort_key specified by sort_keys.
        (If sort_keys is not unique, then we risk looping through values.)
        We use the last row in the previous page as the pagination 'marker'.
        So we must return values that follow the passed marker in the order.
        With a single-valued sort_key, this would be easy: sort_key > X.
        With a compound-values sort_key, (k1, k2, k3) we must do this to repeat
        the lexicographical ordering:
        (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
        We also have to cope with different sort_directions.
        With page_reverse="True" and a marker, the comparisons and the sort
        order are both inverted to fetch the page *before* the marker, and
        the fetched rows are put back into the requested order.
        Typically, the id of the last row is used as the client-facing
        pagination marker, then the actual marker object must be fetched from
        the db and passed in to us as marker.

        :param query: the query object to which we should add
                      paging/sorting/filtering
        :param model: the ORM model class
        :param enforce_valid_params: check for invalid enteries in self.params

        :rtype: tuple
        :returns: (model_list, links): the rows returned by ``query.all()``
                  and a list of types.PageType links, or None for links when
                  pagination is disabled.
        :raises InvalidFilterArgument: on unknown filters when
                  ``enforce_valid_params`` is set.
        :raises InvalidSortKey, InvalidMarker: on bad sort keys / markers.
        """
        reverse_page = bool(CONF.api_settings.allow_pagination and
                            self.marker is not None and
                            self.page_reverse == "True")

        # Add filtering
        if CONF.api_settings.allow_filtering:
            query = self._apply_filtering(query, model, enforce_valid_params)

        # Add sorting
        if CONF.api_settings.allow_sorting:
            query = self._apply_sorting(query, model, reverse_page)

        # Add pagination
        if CONF.api_settings.allow_pagination:
            if self.marker is not None:
                query = self._apply_marker(query, model, reverse_page)
            if self.limit is not None:
                query = query.limit(self.limit)

        model_list = query.all()

        # The rows were fetched walking backwards from the marker; restore
        # the requested order (only meaningful when ORDER BY was applied).
        if reverse_page and CONF.api_settings.allow_sorting:
            model_list.reverse()

        links = None
        if CONF.api_settings.allow_pagination:
            links = self._make_links(model_list)

        return model_list, links

    def _apply_filtering(self, query, model, enforce_valid_params):
        """Apply the non-auxiliary request params to ``query`` as filters.

        Sets ``self.filters`` to the filters accepted for ``model``.

        :raises InvalidFilterArgument: if ``enforce_valid_params`` is set and
            some params are not attributes of the model's data model.
        """
        # Exclude (valid) arguments that are not used for data filtering
        filter_params = {k: v for k, v in self.params.items()
                         if k not in self._auxiliary_arguments}

        # Amphorae are filtered by project through their load balancer.
        secondary_query_filter = filter_params.pop(
            "project_id", None) if (model == models.Amphora) else None

        # Translate arguments from API standard to data model's field name
        filter_params = (
            model.__v2_wsme__.translate_dict_keys_to_data_model(
                filter_params)
        )
        if 'loadbalancer_id' in filter_params:
            filter_params['load_balancer_id'] = filter_params.pop(
                'loadbalancer_id')

        # Pop the 'tags' related parameters off before handling the
        # other filters. Then apply the 'tags' filters after the
        # other filters have been applied.
        tag_params = {}
        for tag_key in ('tags', 'tags-any', 'not-tags', 'not-tags-any'):
            if tag_key in filter_params:
                tag_params[tag_key] = self._prepare_tags_list(
                    filter_params.pop(tag_key))

        # Drop invalid arguments
        self.filters = {k: v for (k, v) in filter_params.items()
                        if k in vars(model.__data_model__())}

        if enforce_valid_params and (
            len(self.filters) < len(filter_params)
        ):
            raise exceptions.InvalidFilterArgument()

        query = model.apply_filter(query, model, self.filters)
        if secondary_query_filter is not None:
            query = query.filter(model.load_balancer.has(
                project_id=secondary_query_filter))

        # Apply tags filtering for the models which support tags.
        return self._apply_tags_filtering(tag_params, model, query)

    def _apply_sorting(self, query, model, reverse_page=False):
        """Add ORDER BY clauses for ``self.sort_keys`` to ``query``.

        Appends the model's default sort keys to ``self.sort_keys`` first.
        With ``reverse_page`` every direction is inverted.

        :raises InvalidSortKey: if a key is not an attribute of the model.
        """
        # Add default sort keys (if they are OK for the model)
        keys_only = [k[0] for k in self.sort_keys]
        for key in constants.DEFAULT_SORT_KEYS:
            if key not in keys_only and hasattr(model, key):
                self.sort_keys.append((key, self.sort_dir))

        for current_sort_key, current_sort_dir in self.sort_keys:
            # Translate sort_key from API standard to data model's name
            current_sort_key = (
                model.__v2_wsme__.translate_key_to_data_model(
                    current_sort_key))
            if reverse_page:
                # Walk backwards from the marker; apply() restores order.
                current_sort_dir = (constants.ASC
                                    if current_sort_dir == constants.DESC
                                    else constants.DESC)
            sort_dir_func = {
                constants.ASC: sqlalchemy.asc,
                constants.DESC: sqlalchemy.desc,
            }[current_sort_dir]

            try:
                # The translated object may be a nested parameter
                # such as vip.ip_address, so handle that case by
                # joining with the nested table.
                if '.' in current_sort_key:
                    parent, child = current_sort_key.split('.')
                    parent_obj = getattr(model, parent)
                    query = query.join(parent_obj)
                    sort_key_attr = child
                else:
                    sort_key_attr = getattr(model, current_sort_key)
            except AttributeError as e:
                raise exceptions.InvalidSortKey(
                    key=current_sort_key) from e
            query = query.order_by(sort_dir_func(sort_key_attr))
        return query

    def _apply_marker(self, query, model, reverse_page=False):
        """Restrict ``query`` to the rows that follow the marker row.

        Builds (k1 > X1) or (k1 == X1 and k2 > X2) or ... where X are the
        marker row's values; ">" becomes "<" for descending keys, and
        ``reverse_page`` swaps the two so rows before the marker match.

        :raises InvalidMarker: if no row of ``model`` has the marker id.
        :raises InvalidSortKey: if a sort key is not an attribute of the
            marker row or of the model.
        :raises InvalidSortDirection: for an unknown sort direction.
        """
        marker_object = self._parse_marker(query.session, model)
        if not marker_object:
            raise exceptions.InvalidMarker(key=self.marker)

        # One (NULL-safe column expression, marker value, direction) triple
        # per sort key, built once and reused by every criterion below.
        sort_columns = []
        for sort_key, sort_dir in self.sort_keys:
            try:
                marker_value = getattr(marker_object, sort_key)
                model_attr = getattr(model, sort_key)
            except AttributeError as e:
                raise exceptions.InvalidSortKey(key=sort_key) from e
            if marker_value is None:
                marker_value = ''  # Default to an empty string if NULL
            null_default = PaginationHelper._get_default_column_value(
                model_attr.property.columns[0].type)
            # Substitute a type-appropriate value for NULL column values.
            column = sa_sql.expression.case(
                [(model_attr.isnot(None),
                  model_attr), ], else_=null_default)
            sort_columns.append((column, marker_value, sort_dir))

        # Build up an array of sort criteria as in the docstring
        criteria_list = []
        for i, (column, marker_value, sort_dir) in enumerate(sort_columns):
            # Every earlier key must equal the marker's value...
            crit_attrs = [prev_column == prev_value
                          for prev_column, prev_value, _
                          in sort_columns[:i]]
            # ...and this key must lie strictly past the marker.
            if sort_dir == constants.ASC:
                use_greater = not reverse_page
            elif sort_dir == constants.DESC:
                use_greater = reverse_page
            else:
                raise exceptions.InvalidSortDirection(key=sort_dir)
            if use_greater:
                crit_attrs.append(column > marker_value)
            else:
                crit_attrs.append(column < marker_value)
            criteria_list.append(sa_sql.and_(*crit_attrs))

        return query.filter(sa_sql.or_(*criteria_list))
|