Make sure sort_key_attr is QueryableAttribute when query
When doing query.order_by, sort_key_attr is get from model class, we need to make sure sort_key_attr is really a QueryableAttribute type instance before we do the query or it will cause errors. This will prevent if there IS a function which name is same as sort_key. Closes-Bug: 1405069 Change-Id: I8a3eb08ab3469ec08e05bfce754b664943d65c83
This commit is contained in:
parent
98b434db7d
commit
bce8ed3042
oslo_db
@ -32,6 +32,7 @@ from sqlalchemy.engine import url as sa_url
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.sql.expression import literal_column
|
||||
@ -147,8 +148,9 @@ def paginate_query(query, model, limit, sort_keys, marker=None,
|
||||
raise ValueError(_("Unknown sort direction, "
|
||||
"must be 'desc' or 'asc'"))
|
||||
try:
|
||||
sort_key_attr = getattr(model, current_sort_key)
|
||||
except AttributeError:
|
||||
sort_key_attr = inspect(model).\
|
||||
all_orm_descriptors[current_sort_key]
|
||||
except KeyError:
|
||||
raise exception.InvalidSortKey()
|
||||
query = query.order_by(sort_dir_func(sort_key_attr))
|
||||
|
||||
|
@ -43,6 +43,7 @@ from oslo_db.tests.old_import_api import utils as test_utils
|
||||
|
||||
|
||||
SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.')))
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class TestSanitizeDbUrl(test_base.BaseTestCase):
|
||||
@ -86,6 +87,17 @@ class FakeModel(object):
|
||||
return '<FakeModel: %s>' % self.values
|
||||
|
||||
|
||||
class FakeTable(Base):
|
||||
__tablename__ = 'fake_table'
|
||||
|
||||
user_id = Column(String(50), primary_key=True)
|
||||
project_id = Column(String(50))
|
||||
snapshot_id = Column(String(50))
|
||||
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestPaginateQuery(test_base.BaseTestCase):
|
||||
def setUp(self):
|
||||
super(TestPaginateQuery, self).setUp()
|
||||
@ -94,23 +106,17 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
self.query = self.mox.CreateMockAnything()
|
||||
self.mox.StubOutWithMock(sqlalchemy, 'asc')
|
||||
self.mox.StubOutWithMock(sqlalchemy, 'desc')
|
||||
self.marker = FakeModel({
|
||||
'user_id': 'user',
|
||||
'project_id': 'p',
|
||||
'snapshot_id': 's',
|
||||
})
|
||||
self.model = FakeModel({
|
||||
'user_id': 'user',
|
||||
'project_id': 'project',
|
||||
'snapshot_id': 'snapshot',
|
||||
})
|
||||
self.marker = FakeTable(user_id='user',
|
||||
project_id='p',
|
||||
snapshot_id='s')
|
||||
self.model = FakeTable
|
||||
|
||||
def test_paginate_query_no_pagination_no_sort_dirs(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc_3')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc_3')
|
||||
self.query.order_by('asc_3').AndReturn(self.query)
|
||||
sqlalchemy.asc('project').AndReturn('asc_2')
|
||||
sqlalchemy.asc(self.model.project_id).AndReturn('asc_2')
|
||||
self.query.order_by('asc_2').AndReturn(self.query)
|
||||
sqlalchemy.asc('snapshot').AndReturn('asc_1')
|
||||
sqlalchemy.asc(self.model.snapshot_id).AndReturn('asc_1')
|
||||
self.query.order_by('asc_1').AndReturn(self.query)
|
||||
self.query.limit(5).AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
@ -118,9 +124,9 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
['user_id', 'project_id', 'snapshot_id'])
|
||||
|
||||
def test_paginate_query_no_pagination(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc')
|
||||
self.query.order_by('asc').AndReturn(self.query)
|
||||
sqlalchemy.desc('project').AndReturn('desc')
|
||||
sqlalchemy.desc(self.model.project_id).AndReturn('desc')
|
||||
self.query.order_by('desc').AndReturn(self.query)
|
||||
self.query.limit(5).AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
@ -129,13 +135,23 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
sort_dirs=['asc', 'desc'])
|
||||
|
||||
def test_paginate_query_attribute_error(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc')
|
||||
self.query.order_by('asc').AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
self.assertRaises(exception.InvalidSortKey,
|
||||
utils.paginate_query, self.query,
|
||||
self.model, 5, ['user_id', 'non-existent key'])
|
||||
|
||||
def test_paginate_query_attribute_error_invalid_sortkey(self):
|
||||
self.assertRaises(exception.InvalidSortKey,
|
||||
utils.paginate_query, self.query,
|
||||
self.model, 5, ['bad_user_id'])
|
||||
|
||||
def test_paginate_query_attribute_error_invalid_sortkey_2(self):
|
||||
self.assertRaises(exception.InvalidSortKey,
|
||||
utils.paginate_query, self.query,
|
||||
self.model, 5, ['foo'])
|
||||
|
||||
def test_paginate_query_assertion_error(self):
|
||||
self.mox.ReplayAll()
|
||||
self.assertRaises(AssertionError,
|
||||
@ -153,13 +169,13 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
sort_dir=None, sort_dirs=['asc', 'desk'])
|
||||
|
||||
def test_paginate_query(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc_1')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
|
||||
self.query.order_by('asc_1').AndReturn(self.query)
|
||||
sqlalchemy.desc('project').AndReturn('desc_1')
|
||||
sqlalchemy.desc(self.model.project_id).AndReturn('desc_1')
|
||||
self.query.order_by('desc_1').AndReturn(self.query)
|
||||
self.mox.StubOutWithMock(sqlalchemy.sql, 'and_')
|
||||
sqlalchemy.sql.and_(False).AndReturn('some_crit')
|
||||
sqlalchemy.sql.and_(True, False).AndReturn('another_crit')
|
||||
sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit')
|
||||
sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit')
|
||||
self.mox.StubOutWithMock(sqlalchemy.sql, 'or_')
|
||||
sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f')
|
||||
self.query.filter('some_f').AndReturn(self.query)
|
||||
@ -171,7 +187,7 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
sort_dirs=['asc', 'desc'])
|
||||
|
||||
def test_paginate_query_value_error(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc_1')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
|
||||
self.query.order_by('asc_1').AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
self.assertRaises(ValueError, utils.paginate_query,
|
||||
|
@ -41,6 +41,7 @@ from oslo_db.sqlalchemy import utils
|
||||
from oslo_db.tests import utils as test_utils
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
SA_VERSION = tuple(map(int, sqlalchemy.__version__.split('.')))
|
||||
|
||||
|
||||
@ -64,6 +65,17 @@ class CustomType(UserDefinedType):
|
||||
return "CustomType"
|
||||
|
||||
|
||||
class FakeTable(Base):
|
||||
__tablename__ = 'fake_table'
|
||||
|
||||
user_id = Column(String(50), primary_key=True)
|
||||
project_id = Column(String(50))
|
||||
snapshot_id = Column(String(50))
|
||||
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
|
||||
class FakeModel(object):
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
@ -93,23 +105,17 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
self.query = self.mox.CreateMockAnything()
|
||||
self.mox.StubOutWithMock(sqlalchemy, 'asc')
|
||||
self.mox.StubOutWithMock(sqlalchemy, 'desc')
|
||||
self.marker = FakeModel({
|
||||
'user_id': 'user',
|
||||
'project_id': 'p',
|
||||
'snapshot_id': 's',
|
||||
})
|
||||
self.model = FakeModel({
|
||||
'user_id': 'user',
|
||||
'project_id': 'project',
|
||||
'snapshot_id': 'snapshot',
|
||||
})
|
||||
self.marker = FakeTable(user_id='user',
|
||||
project_id='p',
|
||||
snapshot_id='s')
|
||||
self.model = FakeTable
|
||||
|
||||
def test_paginate_query_no_pagination_no_sort_dirs(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc_3')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc_3')
|
||||
self.query.order_by('asc_3').AndReturn(self.query)
|
||||
sqlalchemy.asc('project').AndReturn('asc_2')
|
||||
sqlalchemy.asc(self.model.project_id).AndReturn('asc_2')
|
||||
self.query.order_by('asc_2').AndReturn(self.query)
|
||||
sqlalchemy.asc('snapshot').AndReturn('asc_1')
|
||||
sqlalchemy.asc(self.model.snapshot_id).AndReturn('asc_1')
|
||||
self.query.order_by('asc_1').AndReturn(self.query)
|
||||
self.query.limit(5).AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
@ -117,9 +123,9 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
['user_id', 'project_id', 'snapshot_id'])
|
||||
|
||||
def test_paginate_query_no_pagination(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc')
|
||||
self.query.order_by('asc').AndReturn(self.query)
|
||||
sqlalchemy.desc('project').AndReturn('desc')
|
||||
sqlalchemy.desc(self.model.project_id).AndReturn('desc')
|
||||
self.query.order_by('desc').AndReturn(self.query)
|
||||
self.query.limit(5).AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
@ -128,13 +134,23 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
sort_dirs=['asc', 'desc'])
|
||||
|
||||
def test_paginate_query_attribute_error(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc')
|
||||
self.query.order_by('asc').AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
self.assertRaises(exception.InvalidSortKey,
|
||||
utils.paginate_query, self.query,
|
||||
self.model, 5, ['user_id', 'non-existent key'])
|
||||
|
||||
def test_paginate_query_attribute_error_invalid_sortkey(self):
|
||||
self.assertRaises(exception.InvalidSortKey,
|
||||
utils.paginate_query, self.query,
|
||||
self.model, 5, ['bad_user_id'])
|
||||
|
||||
def test_paginate_query_attribute_error_invalid_sortkey_2(self):
|
||||
self.assertRaises(exception.InvalidSortKey,
|
||||
utils.paginate_query, self.query,
|
||||
self.model, 5, ['foo'])
|
||||
|
||||
def test_paginate_query_assertion_error(self):
|
||||
self.mox.ReplayAll()
|
||||
self.assertRaises(AssertionError,
|
||||
@ -152,13 +168,13 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
sort_dir=None, sort_dirs=['asc', 'desk'])
|
||||
|
||||
def test_paginate_query(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc_1')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
|
||||
self.query.order_by('asc_1').AndReturn(self.query)
|
||||
sqlalchemy.desc('project').AndReturn('desc_1')
|
||||
sqlalchemy.desc(self.model.project_id).AndReturn('desc_1')
|
||||
self.query.order_by('desc_1').AndReturn(self.query)
|
||||
self.mox.StubOutWithMock(sqlalchemy.sql, 'and_')
|
||||
sqlalchemy.sql.and_(False).AndReturn('some_crit')
|
||||
sqlalchemy.sql.and_(True, False).AndReturn('another_crit')
|
||||
sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit')
|
||||
sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit')
|
||||
self.mox.StubOutWithMock(sqlalchemy.sql, 'or_')
|
||||
sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f')
|
||||
self.query.filter('some_f').AndReturn(self.query)
|
||||
@ -170,7 +186,7 @@ class TestPaginateQuery(test_base.BaseTestCase):
|
||||
sort_dirs=['asc', 'desc'])
|
||||
|
||||
def test_paginate_query_value_error(self):
|
||||
sqlalchemy.asc('user').AndReturn('asc_1')
|
||||
sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
|
||||
self.query.order_by('asc_1').AndReturn(self.query)
|
||||
self.mox.ReplayAll()
|
||||
self.assertRaises(ValueError, utils.paginate_query,
|
||||
|
Loading…
x
Reference in New Issue
Block a user