Merge "Adding marker, pagination, sort key and sort direction to v2 api"

This commit is contained in:
Jenkins
2013-01-03 05:13:59 +00:00
committed by Gerrit Code Review
9 changed files with 294 additions and 40 deletions

View File

@@ -260,7 +260,9 @@ class VolumeController(wsgi.Controller):
remove_invalid_options(context,
search_opts, self._get_volume_search_options())
volumes = self.volume_api.get_all(context, search_opts=search_opts)
volumes = self.volume_api.get_all(context, marker=None, limit=None,
sort_key='created_at',
sort_dir='desc', filters=search_opts)
limited_list = common.limited(volumes, req)
res = [entity_maker(context, vol) for vol in limited_list]
return {'volumes': res}

View File

@@ -171,20 +171,27 @@ class VolumeController(wsgi.Controller):
def _get_volumes(self, req, is_detail):
"""Returns a list of volumes, transformed through view builder."""
search_opts = {}
search_opts.update(req.GET)
context = req.environ['cinder.context']
params = req.params.copy()
marker = params.pop('marker', None)
limit = params.pop('limit', None)
sort_key = params.pop('sort_key', 'created_at')
sort_dir = params.pop('sort_dir', 'desc')
filters = params
remove_invalid_options(context,
search_opts, self._get_volume_search_options())
filters, self._get_volume_filter_options())
# NOTE(thingee): v2 API allows name instead of display_name
if 'name' in search_opts:
search_opts['display_name'] = search_opts['name']
del search_opts['name']
if 'name' in filters:
filters['display_name'] = filters['name']
del filters['name']
volumes = self.volume_api.get_all(context, search_opts=search_opts)
volumes = self.volume_api.get_all(context, marker, limit, sort_key,
sort_dir, filters)
limited_list = common.limited(volumes, req)
if is_detail:
volumes = self._view_builder.detail_list(req, limited_list)
else:
@@ -273,7 +280,7 @@ class VolumeController(wsgi.Controller):
return retval
def _get_volume_search_options(self):
def _get_volume_filter_options(self):
"""Return volume search options allowed by non-admin."""
return ('name', 'status')
@@ -321,16 +328,16 @@ def create_resource(ext_mgr):
return wsgi.Resource(VolumeController(ext_mgr))
def remove_invalid_options(context, search_options, allowed_search_options):
def remove_invalid_options(context, filters, allowed_search_options):
"""Remove search options that are not valid for non-admin API/context."""
if context.is_admin:
# Allow all options
return
# Otherwise, strip out all unknown options
unknown_options = [opt for opt in search_options
unknown_options = [opt for opt in filters
if opt not in allowed_search_options]
bad_options = ", ".join(unknown_options)
log_msg = _("Removing options '%(bad_options)s' from query") % locals()
log_msg = _("Removing options '%s' from query") % bad_options
LOG.debug(log_msg)
for opt in unknown_options:
del search_options[opt]
del filters[opt]

128
cinder/common/sqlalchemyutils.py Executable file
View File

@@ -0,0 +1,128 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2010-2011 OpenStack LLC.
# Copyright 2012 Justin Santa Barbara
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Implementation of paginate query."""
import sqlalchemy
from cinder import exception
from cinder.openstack.common import log as logging
LOG = logging.getLogger(__name__)
# copied from glance/db/sqlalchemy/api.py
def paginate_query(query, model, limit, sort_keys, marker=None,
sort_dir=None, sort_dirs=None):
"""Returns a query with sorting / pagination criteria added.
Pagination works by requiring a unique sort_key, specified by sort_keys.
(If sort_keys is not unique, then we risk looping through values.)
We use the last row in the previous page as the 'marker' for pagination.
So we must return values that follow the passed marker in the order.
With a single-valued sort_key, this would be easy: sort_key > X.
With a compound-values sort_key, (k1, k2, k3) we must do this to repeat
the lexicographical ordering:
(k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
We also have to cope with different sort_directions.
Typically, the id of the last row is used as the client-facing pagination
marker, then the actual marker object must be fetched from the db and
passed in to us as marker.
:param query: the query object to which we should add paging/sorting
:param model: the ORM model class
:param limit: maximum number of items to return
:param sort_keys: array of attributes by which results should be sorted
:param marker: the last item of the previous page; we returns the next
results after this value.
:param sort_dir: direction in which results should be sorted (asc, desc)
:param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
:rtype: sqlalchemy.orm.query.Query
:return: The query with sorting/pagination added.
"""
if 'id' not in sort_keys:
# TODO(justinsb): If this ever gives a false-positive, check
# the actual primary key, rather than assuming its id
LOG.warn(_('Id not in sort_keys; is sort_keys unique?'))
assert(not (sort_dir and sort_dirs))
# Default the sort direction to ascending
if sort_dirs is None and sort_dir is None:
sort_dir = 'asc'
# Ensure a per-column sort direction
if sort_dirs is None:
sort_dirs = [sort_dir for _sort_key in sort_keys]
assert(len(sort_dirs) == len(sort_keys))
# Add sorting
for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs):
sort_dir_func = {
'asc': sqlalchemy.asc,
'desc': sqlalchemy.desc,
}[current_sort_dir]
try:
sort_key_attr = getattr(model, current_sort_key)
except AttributeError:
raise exception.InvalidInput(reason='Invalid sort key')
query = query.order_by(sort_dir_func(sort_key_attr))
# Add pagination
if marker is not None:
marker_values = []
for sort_key in sort_keys:
v = getattr(marker, sort_key)
marker_values.append(v)
# Build up an array of sort criteria as in the docstring
criteria_list = []
for i in xrange(0, len(sort_keys)):
crit_attrs = []
for j in xrange(0, i):
model_attr = getattr(model, sort_keys[j])
crit_attrs.append((model_attr == marker_values[j]))
model_attr = getattr(model, sort_keys[i])
if sort_dirs[i] == 'desc':
crit_attrs.append((model_attr < marker_values[i]))
elif sort_dirs[i] == 'asc':
crit_attrs.append((model_attr > marker_values[i]))
else:
raise ValueError(_("Unknown sort direction, "
"must be 'desc' or 'asc'"))
criteria = sqlalchemy.sql.and_(*crit_attrs)
criteria_list.append(criteria)
f = sqlalchemy.sql.or_(*criteria_list)
query = query.filter(f)
if limit is not None:
query = query.limit(limit)
return query

View File

@@ -229,9 +229,9 @@ def volume_get(context, volume_id):
return IMPL.volume_get(context, volume_id)
def volume_get_all(context):
def volume_get_all(context, marker, limit, sort_key, sort_dir):
"""Get all volumes."""
return IMPL.volume_get_all(context)
return IMPL.volume_get_all(context, marker, limit, sort_key, sort_dir)
def volume_get_all_by_host(context, host):
@@ -244,9 +244,11 @@ def volume_get_all_by_instance_uuid(context, instance_uuid):
return IMPL.volume_get_all_by_instance_uuid(context, instance_uuid)
def volume_get_all_by_project(context, project_id):
def volume_get_all_by_project(context, project_id, marker, limit, sort_key,
sort_dir):
"""Get all volumes belonging to a project."""
return IMPL.volume_get_all_by_project(context, project_id)
return IMPL.volume_get_all_by_project(context, project_id, marker, limit,
sort_key, sort_dir)
def volume_get_iscsi_target_num(context, volume_id):

View File

@@ -29,6 +29,7 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql import func
from cinder.common import sqlalchemyutils
from cinder import db
from cinder.db.sqlalchemy import models
from cinder.db.sqlalchemy.session import get_session
@@ -1022,8 +1023,19 @@ def volume_get(context, volume_id, session=None):
@require_admin_context
def volume_get_all(context):
return _volume_get_query(context).all()
def volume_get_all(context, marker, limit, sort_key, sort_dir):
query = _volume_get_query(context)
marker_volume = None
if marker is not None:
marker_volume = volume_get(context, marker)
query = sqlalchemyutils.paginate_query(query, models.Volume, limit,
[sort_key, 'created_at', 'id'],
marker=marker_volume,
sort_dir=sort_dir)
return query.all()
@require_admin_context
@@ -1046,9 +1058,21 @@ def volume_get_all_by_instance_uuid(context, instance_uuid):
@require_context
def volume_get_all_by_project(context, project_id):
def volume_get_all_by_project(context, project_id, marker, limit, sort_key,
sort_dir):
authorize_project_context(context, project_id)
return _volume_get_query(context).filter_by(project_id=project_id).all()
query = _volume_get_query(context).filter_by(project_id=project_id)
marker_volume = None
if marker is not None:
marker_volume = volume_get(context, marker)
query = sqlalchemyutils.paginate_query(query, models.Volume, limit,
[sort_key, 'created_at', 'id'],
marker=marker_volume,
sort_dir=sort_dir)
return query.all()
@require_admin_context

View File

@@ -331,7 +331,8 @@ class VolumeApiTest(test.TestCase):
self.assertEqual(res_dict, expected)
def test_volume_list_by_name(self):
def stub_volume_get_all_by_project(context, project_id):
def stub_volume_get_all_by_project(context, project_id, marker, limit,
sort_key, sort_dir):
return [
stubs.stub_volume(1, display_name='vol1'),
stubs.stub_volume(2, display_name='vol2'),
@@ -355,7 +356,8 @@ class VolumeApiTest(test.TestCase):
self.assertEqual(len(resp['volumes']), 0)
def test_volume_list_by_status(self):
def stub_volume_get_all_by_project(context, project_id):
def stub_volume_get_all_by_project(context, project_id, marker, limit,
sort_key, sort_dir):
return [
stubs.stub_volume(1, display_name='vol1', status='available'),
stubs.stub_volume(2, display_name='vol2', status='available'),

View File

@@ -91,13 +91,15 @@ def stub_volume_get_notfound(self, context, volume_id):
raise exc.NotFound
def stub_volume_get_all(context, search_opts=None):
def stub_volume_get_all(context, search_opts=None, marker=None, limit=None,
sort_key='created_at', sort_dir='desc'):
return [stub_volume(100, project_id='fake'),
stub_volume(101, project_id='superfake'),
stub_volume(102, project_id='superduperfake')]
def stub_volume_get_all_by_project(self, context, search_opts=None):
def stub_volume_get_all_by_project(self, context, marker, limit, sort_key,
sort_dir, filters={}):
return [stub_volume_get(self, context, '1')]

View File

@@ -392,8 +392,91 @@ class VolumeApiTest(test.TestCase):
}
self.assertEqual(res_dict, expected)
def test_volume_index_with_marker(self):
def stub_volume_get_all_by_project(context, project_id, marker, limit,
sort_key, sort_dir):
return [
stubs.stub_volume(1, display_name='vol1'),
stubs.stub_volume(2, display_name='vol2'),
]
self.stubs.Set(db, 'volume_get_all_by_project',
stub_volume_get_all_by_project)
req = fakes.HTTPRequest.blank('/v2/volumes?marker=1')
res_dict = self.controller.index(req)
volumes = res_dict['volumes']
self.assertEquals(len(volumes), 2)
self.assertEquals(volumes[0]['id'], 1)
self.assertEquals(volumes[1]['id'], 2)
def test_volume_index_limit(self):
req = fakes.HTTPRequest.blank('/v2/volumes?limit=1')
res_dict = self.controller.index(req)
volumes = res_dict['volumes']
self.assertEquals(len(volumes), 1)
def test_volume_index_limit_negative(self):
req = fakes.HTTPRequest.blank('/v2/volumes?limit=-1')
self.assertRaises(webob.exc.HTTPBadRequest,
self.controller.index,
req)
def test_volume_index_limit_non_int(self):
req = fakes.HTTPRequest.blank('/v2/volumes?limit=a')
self.assertRaises(webob.exc.HTTPBadRequest,
self.controller.index,
req)
def test_volume_index_limit_marker(self):
req = fakes.HTTPRequest.blank('/v2/volumes?marker=1&limit=1')
res_dict = self.controller.index(req)
volumes = res_dict['volumes']
self.assertEquals(len(volumes), 1)
self.assertEquals(volumes[0]['id'], '1')
def test_volume_detail_with_marker(self):
def stub_volume_get_all_by_project(context, project_id, marker, limit,
sort_key, sort_dir):
return [
stubs.stub_volume(1, display_name='vol1'),
stubs.stub_volume(2, display_name='vol2'),
]
self.stubs.Set(db, 'volume_get_all_by_project',
stub_volume_get_all_by_project)
req = fakes.HTTPRequest.blank('/v2/volumes/detail?marker=1')
res_dict = self.controller.index(req)
volumes = res_dict['volumes']
self.assertEquals(len(volumes), 2)
self.assertEquals(volumes[0]['id'], 1)
self.assertEquals(volumes[1]['id'], 2)
def test_volume_detail_limit(self):
req = fakes.HTTPRequest.blank('/v2/volumes/detail?limit=1')
res_dict = self.controller.index(req)
volumes = res_dict['volumes']
self.assertEquals(len(volumes), 1)
def test_volume_detail_limit_negative(self):
req = fakes.HTTPRequest.blank('/v2/volumes/detail?limit=-1')
self.assertRaises(webob.exc.HTTPBadRequest,
self.controller.index,
req)
def test_volume_detail_limit_non_int(self):
req = fakes.HTTPRequest.blank('/v2/volumes/detail?limit=a')
self.assertRaises(webob.exc.HTTPBadRequest,
self.controller.index,
req)
def test_volume_detail_limit_marker(self):
req = fakes.HTTPRequest.blank('/v2/volumes/detail?marker=1&limit=1')
res_dict = self.controller.index(req)
volumes = res_dict['volumes']
self.assertEquals(len(volumes), 1)
self.assertEquals(volumes[0]['id'], '1')
def test_volume_list_by_name(self):
def stub_volume_get_all_by_project(context, project_id):
def stub_volume_get_all_by_project(context, project_id, marker, limit,
sort_key, sort_dir):
return [
stubs.stub_volume(1, display_name='vol1'),
stubs.stub_volume(2, display_name='vol2'),
@@ -408,7 +491,6 @@ class VolumeApiTest(test.TestCase):
self.assertEqual(len(resp['volumes']), 3)
# filter on name
req = fakes.HTTPRequest.blank('/v2/volumes?name=vol2')
#import pdb; pdb.set_trace()
resp = self.controller.index(req)
self.assertEqual(len(resp['volumes']), 1)
self.assertEqual(resp['volumes'][0]['name'], 'vol2')
@@ -418,7 +500,8 @@ class VolumeApiTest(test.TestCase):
self.assertEqual(len(resp['volumes']), 0)
def test_volume_list_by_status(self):
def stub_volume_get_all_by_project(context, project_id):
def stub_volume_get_all_by_project(context, project_id, marker, limit,
sort_key, sort_dir):
return [
stubs.stub_volume(1, display_name='vol1', status='available'),
stubs.stub_volume(2, display_name='vol2', status='available'),

View File

@@ -23,6 +23,7 @@ Handles all requests relating to volumes.
import functools
from cinder.db import base
from cinder.db.sqlalchemy import models
from cinder import exception
from cinder import flags
from cinder.image import glance
@@ -267,21 +268,23 @@ class API(base.Base):
check_policy(context, 'get', volume)
return volume
def get_all(self, context, search_opts=None):
def get_all(self, context, marker=None, limit=None, sort_key='created_at',
sort_dir='desc', filters={}):
check_policy(context, 'get_all')
if search_opts is None:
search_opts = {}
if (context.is_admin and 'all_tenants' in search_opts):
if (context.is_admin and 'all_tenants' in filters):
# Need to remove all_tenants to pass the filtering below.
del search_opts['all_tenants']
volumes = self.db.volume_get_all(context)
del filters['all_tenants']
volumes = self.db.volume_get_all(context, marker, limit, sort_key,
sort_dir)
else:
volumes = self.db.volume_get_all_by_project(context,
context.project_id)
if search_opts:
LOG.debug(_("Searching by: %s") % str(search_opts))
context.project_id,
marker, limit,
sort_key, sort_dir)
if filters:
LOG.debug(_("Searching by: %s") % str(filters))
def _check_metadata_match(volume, searchdict):
volume_metadata = {}
@@ -301,7 +304,7 @@ class API(base.Base):
not_found = object()
for volume in volumes:
# go over all filters in the list
for opt, values in search_opts.iteritems():
for opt, values in filters.iteritems():
try:
filter_func = filter_mapping[opt]
except KeyError:
@@ -312,6 +315,7 @@ class API(base.Base):
else: # did not break out loop
result.append(volume) # volume matches all filters
volumes = result
return volumes
def get_snapshot(self, context, snapshot_id):