diff --git a/keystone/assignment/role_backends/sql.py b/keystone/assignment/role_backends/sql.py index 6ede9eca6a..685b9e052b 100644 --- a/keystone/assignment/role_backends/sql.py +++ b/keystone/assignment/role_backends/sql.py @@ -11,6 +11,7 @@ # under the License. from keystone import assignment +from keystone.common import driver_hints from keystone.common import sql from keystone import exception @@ -24,7 +25,7 @@ class Role(assignment.RoleDriverV9): session.add(ref) return ref.to_dict() - @sql.truncated + @driver_hints.truncated def list_roles(self, hints): with sql.transaction() as session: query = session.query(RoleTable) diff --git a/keystone/catalog/backends/sql.py b/keystone/catalog/backends/sql.py index b82ca94c77..bf5ed02fe0 100644 --- a/keystone/catalog/backends/sql.py +++ b/keystone/catalog/backends/sql.py @@ -21,6 +21,7 @@ from sqlalchemy.sql import true from keystone import catalog from keystone.catalog import core +from keystone.common import driver_hints from keystone.common import sql from keystone import exception from keystone.i18n import _ @@ -171,7 +172,7 @@ class Catalog(catalog.CatalogDriverV8): return ref.to_dict() # Services - @sql.truncated + @driver_hints.truncated def list_services(self, hints): session = sql.get_session() services = session.query(Service) @@ -240,7 +241,7 @@ class Catalog(catalog.CatalogDriverV8): session = sql.get_session() return self._get_endpoint(session, endpoint_id).to_dict() - @sql.truncated + @driver_hints.truncated def list_endpoints(self, hints): session = sql.get_session() endpoints = session.query(Endpoint) diff --git a/keystone/common/driver_hints.py b/keystone/common/driver_hints.py index f11ad0ed06..b578f985c9 100644 --- a/keystone/common/driver_hints.py +++ b/keystone/common/driver_hints.py @@ -13,6 +13,50 @@ # License for the specific language governing permissions and limitations # under the License. +import functools + +from keystone import exception +from keystone.i18n import _ + + +def truncated(f): + """Ensure list truncation is detected in Driver list entity methods. + + This is designed to wrap Driver list_{entity} methods in order to + calculate if the resultant list has been truncated. Provided a limit dict + is found in the hints list, we increment the limit by one so as to ask the + wrapped function for one more entity than the limit, and then once the list + has been generated, we check to see if the original limit has been + exceeded, in which case we truncate back to that limit and set the + 'truncated' boolean to 'true' in the hints limit dict. + + """ + @functools.wraps(f) + def wrapper(self, hints, *args, **kwargs): + if not hasattr(hints, 'limit'): + raise exception.UnexpectedError( + _('Cannot truncate a driver call without hints list as ' + 'first parameter after self ')) + + if hints.limit is None: + return f(self, hints, *args, **kwargs) + + # A limit is set, so ask for one more entry than we need + list_limit = hints.limit['limit'] + hints.set_limit(list_limit + 1) + ref_list = f(self, hints, *args, **kwargs) + + # If we got more than the original limit then trim back the list and + # mark it truncated. In both cases, make sure we set the limit back + # to its original value. + if len(ref_list) > list_limit: + hints.set_limit(list_limit, truncated=True) + return ref_list[:list_limit] + else: + hints.set_limit(list_limit) + return ref_list + return wrapper + class Hints(object): """Encapsulate driver hints for listing entities. diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py index e03c1f83fa..7c8689d722 100644 --- a/keystone/common/sql/core.py +++ b/keystone/common/sql/core.py @@ -34,6 +34,7 @@ from sqlalchemy.ext import declarative from sqlalchemy.orm.attributes import flag_modified, InstrumentedAttribute from sqlalchemy import types as sql_types +from keystone.common import driver_hints from keystone.common import utils from keystone import exception from keystone.i18n import _ @@ -200,42 +201,7 @@ def transaction(expire_on_commit=False): def truncated(f): - """Ensure list truncation is detected in Driver list entity methods. - - This is designed to wrap and sql Driver list_{entity} methods in order to - calculate if the resultant list has been truncated. Provided a limit dict - is found in the hints list, we increment the limit by one so as to ask the - wrapped function for one more entity than the limit, and then once the list - has been generated, we check to see if the original limit has been - exceeded, in which case we truncate back to that limit and set the - 'truncated' boolean to 'true' in the hints limit dict. - - """ - @functools.wraps(f) - def wrapper(self, hints, *args, **kwargs): - if not hasattr(hints, 'limit'): - raise exception.UnexpectedError( - _('Cannot truncate a driver call without hints list as ' - 'first parameter after self ')) - - if hints.limit is None: - return f(self, hints, *args, **kwargs) - - # A limit is set, so ask for one more entry than we need - list_limit = hints.limit['limit'] - hints.set_limit(list_limit + 1) - ref_list = f(self, hints, *args, **kwargs) - - # If we got more than the original limit then trim back the list and - # mark it truncated. In both cases, make sure we set the limit back - # to its original value. - if len(ref_list) > list_limit: - hints.set_limit(list_limit, truncated=True) - return ref_list[:list_limit] - else: - hints.set_limit(list_limit) - return ref_list - return wrapper + return driver_hints.truncated(f) class _WontMatch(Exception): diff --git a/keystone/credential/backends/sql.py b/keystone/credential/backends/sql.py index 6dc9cd6509..349c03660e 100644 --- a/keystone/credential/backends/sql.py +++ b/keystone/credential/backends/sql.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +from keystone.common import driver_hints from keystone.common import sql from keystone import credential from keystone import exception @@ -41,7 +42,7 @@ class Credential(credential.CredentialDriverV8): session.add(ref) return ref.to_dict() - @sql.truncated + @driver_hints.truncated def list_credentials(self, hints): session = sql.get_session() credentials = session.query(CredentialModel) diff --git a/keystone/identity/backends/sql.py b/keystone/identity/backends/sql.py index a899dbf5b0..d39424be3f 100644 --- a/keystone/identity/backends/sql.py +++ b/keystone/identity/backends/sql.py @@ -14,6 +14,7 @@ from oslo_config import cfg +from keystone.common import driver_hints from keystone.common import sql from keystone.common import utils from keystone import exception @@ -118,7 +119,7 @@ class Identity(identity.IdentityDriverV8): session.add(user_ref) return identity.filter_user(user_ref.to_dict()) - @sql.truncated + @driver_hints.truncated def list_users(self, hints): session = sql.get_session() query = session.query(User) @@ -249,7 +250,7 @@ class Identity(identity.IdentityDriverV8): session.add(ref) return ref.to_dict() - @sql.truncated + @driver_hints.truncated def list_groups(self, hints): session = sql.get_session() query = session.query(Group) diff --git a/keystone/resource/backends/sql.py b/keystone/resource/backends/sql.py index 08f6e52883..210a380339 100644 --- a/keystone/resource/backends/sql.py +++ b/keystone/resource/backends/sql.py @@ -14,6 +14,7 @@ from oslo_config import cfg from oslo_log import log from keystone.common import clean +from keystone.common import driver_hints from keystone.common import sql from keystone import exception from keystone.i18n import _LE @@ -50,7 +51,7 @@ class Resource(keystone_resource.ResourceDriverV8): raise exception.ProjectNotFound(project_id=tenant_name) return project_ref.to_dict() - @sql.truncated + @driver_hints.truncated def list_projects(self, hints): with sql.transaction() as session: query = session.query(Project) @@ -176,7 +177,7 @@ class Resource(keystone_resource.ResourceDriverV8): session.add(ref) return ref.to_dict() - @sql.truncated + @driver_hints.truncated def list_domains(self, hints): with sql.transaction() as session: query = session.query(Domain)