From da6ea7e2243aa13e626a2dbd5407477fcbd79c9c Mon Sep 17 00:00:00 2001 From: Jamie Lennox Date: Thu, 19 May 2016 23:06:14 +1000 Subject: [PATCH] Pass a request to controllers instead of a context Instead of the unformed context dictionary pass a full request object with access to the context_dict so that existing functions still work. After this we can replace smaller usages of the context dict with functions and properties on the request directly. Change-Id: Ibe822ed7c76a24a7d31d98ce62f873a01e5fb213 --- keystone/assignment/controllers.py | 177 +++++------ keystone/auth/controllers.py | 60 ++-- keystone/catalog/controllers.py | 183 ++++++------ keystone/common/controller.py | 44 +-- keystone/common/wsgi.py | 2 +- keystone/contrib/ec2/controllers.py | 51 ++-- keystone/credential/controllers.py | 22 +- keystone/endpoint_policy/controllers.py | 27 +- keystone/federation/controllers.py | 120 ++++---- keystone/identity/controllers.py | 135 +++++---- keystone/oauth1/controllers.py | 86 +++--- keystone/policy/controllers.py | 27 +- keystone/resource/controllers.py | 107 +++---- keystone/revoke/controllers.py | 8 +- keystone/tests/unit/core.py | 15 + .../tests/unit/identity/test_controllers.py | 10 +- .../tests/unit/resource/test_controllers.py | 8 +- keystone/tests/unit/test_auth.py | 282 +++++++++--------- keystone/tests/unit/test_auth_plugin.py | 17 +- keystone/tests/unit/test_cert_setup.py | 2 +- keystone/tests/unit/test_v2_controller.py | 42 ++- keystone/tests/unit/test_v3.py | 7 +- keystone/tests/unit/test_v3_auth.py | 55 ++-- keystone/tests/unit/test_v3_federation.py | 94 +++--- keystone/tests/unit/test_wsgi.py | 28 +- keystone/token/controllers.py | 30 +- keystone/trust/controllers.py | 55 ++-- keystone/v2_crud/user_crud.py | 16 +- keystone/version/controllers.py | 20 +- 29 files changed, 906 insertions(+), 824 deletions(-) diff --git a/keystone/assignment/controllers.py b/keystone/assignment/controllers.py index 1e2267637..c5a07a5ee 100644 --- a/keystone/assignment/controllers.py +++ b/keystone/assignment/controllers.py @@ -42,7 +42,7 @@ class TenantAssignment(controller.V2Controller): """The V2 Project APIs that are processing assignments.""" @controller.v2_auth_deprecated - def get_projects_for_token(self, context, **kw): + def get_projects_for_token(self, request, **kw): """Get valid tenants for token based on token used to authenticate. Pulls the token from the context, validates it and gets the valid @@ -51,21 +51,21 @@ class TenantAssignment(controller.V2Controller): Doesn't care about token scopedness. """ - token_ref = utils.get_token_ref(context) + token_ref = utils.get_token_ref(request.context_dict) tenant_refs = ( self.assignment_api.list_projects_for_user(token_ref.user_id)) tenant_refs = [self.v3_to_v2_project(ref) for ref in tenant_refs if ref['domain_id'] == CONF.identity.default_domain_id] params = { - 'limit': context['query_string'].get('limit'), - 'marker': context['query_string'].get('marker'), + 'limit': request.context_dict['query_string'].get('limit'), + 'marker': request.context_dict['query_string'].get('marker'), } return self.format_project_list(tenant_refs, **params) @controller.v2_deprecated - def get_project_users(self, context, tenant_id, **kw): - self.assert_admin(context) + def get_project_users(self, request, tenant_id, **kw): + self.assert_admin(request.context_dict) user_refs = [] user_ids = self.assignment_api.list_user_ids_for_project(tenant_id) for user_id in user_ids: @@ -87,14 +87,14 @@ class Role(controller.V2Controller): """The Role management APIs.""" @controller.v2_deprecated - def get_role(self, context, role_id): - self.assert_admin(context) + def get_role(self, request, role_id): + self.assert_admin(request.context_dict) return {'role': self.role_api.get_role(role_id)} @controller.v2_deprecated - def create_role(self, context, role): + def create_role(self, request, role): role = self._normalize_dict(role) - self.assert_admin(context) + self.assert_admin(request.context_dict) if 'name' not in role or not role['name']: msg = _('Name field is required and cannot be empty') @@ -109,19 +109,19 @@ class Role(controller.V2Controller): role_id = uuid.uuid4().hex role['id'] = role_id - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) role_ref = self.role_api.create_role(role_id, role, initiator) return {'role': role_ref} @controller.v2_deprecated - def delete_role(self, context, role_id): - self.assert_admin(context) - initiator = notifications._get_request_audit_info(context) + def delete_role(self, request, role_id): + self.assert_admin(request.context_dict) + initiator = notifications._get_request_audit_info(request.context_dict) self.role_api.delete_role(role_id, initiator) @controller.v2_deprecated - def get_roles(self, context): - self.assert_admin(context) + def get_roles(self, request): + self.assert_admin(request.context_dict) return {'roles': self.role_api.list_roles()} @@ -131,14 +131,14 @@ class RoleAssignmentV2(controller.V2Controller): # COMPAT(essex-3) @controller.v2_deprecated - def get_user_roles(self, context, user_id, tenant_id=None): + def get_user_roles(self, request, user_id, tenant_id=None): """Get the roles for a user and tenant pair. Since we're trying to ignore the idea of user-only roles we're not implementing them in hopes that the idea will die off. """ - self.assert_admin(context) + self.assert_admin(request.context_dict) # NOTE(davechen): Router without project id is defined, # but we don't plan on implementing this. if tenant_id is None: @@ -150,14 +150,14 @@ class RoleAssignmentV2(controller.V2Controller): for x in roles]} @controller.v2_deprecated - def add_role_to_user(self, context, user_id, role_id, tenant_id=None): + def add_role_to_user(self, request, user_id, role_id, tenant_id=None): """Add a role to a user and tenant pair. Since we're trying to ignore the idea of user-only roles we're not implementing them in hopes that the idea will die off. """ - self.assert_admin(context) + self.assert_admin(request.context_dict) if tenant_id is None: raise exception.NotImplemented( message=_('User roles not supported: tenant_id required')) @@ -169,14 +169,14 @@ class RoleAssignmentV2(controller.V2Controller): return {'role': role_ref} @controller.v2_deprecated - def remove_role_from_user(self, context, user_id, role_id, tenant_id=None): + def remove_role_from_user(self, request, user_id, role_id, tenant_id=None): """Remove a role from a user and tenant pair. Since we're trying to ignore the idea of user-only roles we're not implementing them in hopes that the idea will die off. """ - self.assert_admin(context) + self.assert_admin(request.context_dict) if tenant_id is None: raise exception.NotImplemented( message=_('User roles not supported: tenant_id required')) @@ -188,7 +188,7 @@ class RoleAssignmentV2(controller.V2Controller): # COMPAT(diablo): CRUD extension @controller.v2_deprecated - def get_role_refs(self, context, user_id): + def get_role_refs(self, request, user_id): """Ultimate hack to get around having to make role_refs first-class. This will basically iterate over the various roles the user has in @@ -197,7 +197,7 @@ class RoleAssignmentV2(controller.V2Controller): up the appropriate data when we need to delete them. """ - self.assert_admin(context) + self.assert_admin(request.context_dict) tenants = self.assignment_api.list_projects_for_user(user_id) o = [] for tenant in tenants: @@ -217,14 +217,14 @@ class RoleAssignmentV2(controller.V2Controller): # COMPAT(diablo): CRUD extension @controller.v2_deprecated - def create_role_ref(self, context, user_id, role): + def create_role_ref(self, request, user_id, role): """Used for adding a user to a tenant. In the legacy data model adding a user to a tenant required setting a role. """ - self.assert_admin(context) + self.assert_admin(request.context_dict) # TODO(termie): for now we're ignoring the actual role tenant_id = role.get('tenantId') role_id = role.get('roleId') @@ -236,7 +236,7 @@ class RoleAssignmentV2(controller.V2Controller): # COMPAT(diablo): CRUD extension @controller.v2_deprecated - def delete_role_ref(self, context, user_id, role_ref_id): + def delete_role_ref(self, request, user_id, role_ref_id): """Used for deleting a user from a tenant. In the legacy data model removing a user from a tenant required @@ -247,7 +247,7 @@ class RoleAssignmentV2(controller.V2Controller): we remove the user from the tenant. """ - self.assert_admin(context) + self.assert_admin(request.context_dict) # TODO(termie): for now we're ignoring the actual role role_ref_ref = urllib.parse.parse_qs(role_ref_id) tenant_id = role_ref_ref.get('tenantId')[0] @@ -268,11 +268,14 @@ class ProjectAssignmentV3(controller.V3Controller): self.get_member_from_driver = self.resource_api.get_project @controller.filterprotected('domain_id', 'enabled', 'name') - def list_user_projects(self, context, filters, user_id): - hints = ProjectAssignmentV3.build_driver_hints(context, filters) + def list_user_projects(self, request, filters, user_id): + hints = ProjectAssignmentV3.build_driver_hints(request.context_dict, + filters) refs = self.assignment_api.list_projects_for_user(user_id, hints=hints) - return ProjectAssignmentV3.wrap_collection(context, refs, hints=hints) + return ProjectAssignmentV3.wrap_collection(request.context_dict, + refs, + hints=hints) @dependency.requires('role_api') @@ -319,33 +322,33 @@ class RoleV3(controller.V3Controller): @controller.protected() @validation.validated(schema.role_create, 'role') - def create_role(self, context, role): - return self._create_role(context, role) + def create_role(self, request, role): + return self._create_role(request.context_dict, role) @controller.protected() @validation.validated(schema.role_create, 'role') - def create_domain_role(self, context, role): - return self._create_role(context, role) + def create_domain_role(self, request, role): + return self._create_role(request.context_dict, role) - def list_roles_wrapper(self, context): + def list_roles_wrapper(self, request): # If there is no domain_id filter defined, then we only want to return # global roles, so we set the domain_id filter to None. - params = context['query_string'] + params = request.context_dict['query_string'] if 'domain_id' not in params: - context['query_string']['domain_id'] = None + request.context_dict['query_string']['domain_id'] = None - if context['query_string']['domain_id'] is not None: - return self.list_domain_roles(context) + if request.context_dict['query_string']['domain_id'] is not None: + return self.list_domain_roles(request) else: - return self.list_roles(context) + return self.list_roles(request) @controller.filterprotected('name', 'domain_id') - def list_roles(self, context, filters): - return self._list_roles(context, filters) + def list_roles(self, request, filters): + return self._list_roles(request.context_dict, filters) @controller.filterprotected('name', 'domain_id') - def list_domain_roles(self, context, filters): - return self._list_roles(context, filters) + def list_domain_roles(self, request, filters): + return self._list_roles(request.context_dict, filters) def get_role_wrapper(self, context, role_id): if self._is_domain_role_target(role_id): @@ -354,12 +357,12 @@ class RoleV3(controller.V3Controller): return self.get_role(context, role_id=role_id) @controller.protected() - def get_role(self, context, role_id): - return self._get_role(context, role_id) + def get_role(self, request, role_id): + return self._get_role(request.context_dict, role_id) @controller.protected() - def get_domain_role(self, context, role_id): - return self._get_role(context, role_id) + def get_domain_role(self, request, role_id): + return self._get_role(request.context_dict, role_id) def update_role_wrapper(self, context, role_id, role): # Since we don't allow you change whether a role is global or domain @@ -373,13 +376,13 @@ class RoleV3(controller.V3Controller): @controller.protected() @validation.validated(schema.role_update, 'role') - def update_role(self, context, role_id, role): - return self._update_role(context, role_id, role) + def update_role(self, request, role_id, role): + return self._update_role(request.context_dict, role_id, role) @controller.protected() @validation.validated(schema.role_update, 'role') - def update_domain_role(self, context, role_id, role): - return self._update_role(context, role_id, role) + def update_domain_role(self, request, role_id, role): + return self._update_role(request.context_dict, role_id, role) def delete_role_wrapper(self, context, role_id): if self._is_domain_role_target(role_id): @@ -388,12 +391,12 @@ class RoleV3(controller.V3Controller): return self.delete_role(context, role_id=role_id) @controller.protected() - def delete_role(self, context, role_id): - return self._delete_role(context, role_id) + def delete_role(self, request, role_id): + return self._delete_role(request.context_dict, role_id) @controller.protected() - def delete_domain_role(self, context, role_id): - return self._delete_role(context, role_id) + def delete_domain_role(self, request, role_id): + return self._delete_role(request.context_dict, role_id) def _create_role(self, context, role): if role['name'] == CONF.member_role_name: @@ -484,38 +487,40 @@ class ImpliedRolesV3(controller.V3Controller): return response @controller.protected() - def get_implied_role(self, context, prior_role_id, implied_role_id): + def get_implied_role(self, request, prior_role_id, implied_role_id): ref = self.role_api.get_implied_role(prior_role_id, implied_role_id) prior_id = ref['prior_role_id'] implied_id = ref['implied_role_id'] endpoint = super(controller.V3Controller, ImpliedRolesV3).base_url( - context, 'public') + request.context_dict, 'public') response = self._populate_implied_role_response( endpoint, prior_id, implied_id) return response @controller.protected() - def check_implied_role(self, context, prior_role_id, implied_role_id): + def check_implied_role(self, request, prior_role_id, implied_role_id): self.role_api.get_implied_role(prior_role_id, implied_role_id) @controller.protected() - def create_implied_role(self, context, prior_role_id, implied_role_id): + def create_implied_role(self, request, prior_role_id, implied_role_id): self.role_api.create_implied_role(prior_role_id, implied_role_id) return wsgi.render_response( - self.get_implied_role(context, prior_role_id, implied_role_id), + self.get_implied_role(request, + prior_role_id, + implied_role_id), status=(201, 'Created')) @controller.protected() - def delete_implied_role(self, context, prior_role_id, implied_role_id): + def delete_implied_role(self, request, prior_role_id, implied_role_id): self.role_api.delete_implied_role(prior_role_id, implied_role_id) @controller.protected() - def list_implied_roles(self, context, prior_role_id): + def list_implied_roles(self, request, prior_role_id): ref = self.role_api.list_implied_roles(prior_role_id) implied_ids = [r['implied_role_id'] for r in ref] endpoint = super(controller.V3Controller, ImpliedRolesV3).base_url( - context, 'public') + request.context_dict, 'public') results = self._populate_implied_roles_response( endpoint, prior_role_id, implied_ids) @@ -523,14 +528,14 @@ class ImpliedRolesV3(controller.V3Controller): return results @controller.protected() - def list_role_inference_rules(self, context): + def list_role_inference_rules(self, request): refs = self.role_api.list_role_inference_rules() role_dict = {role_ref['id']: role_ref for role_ref in self.role_api.list_roles()} rules = dict() endpoint = super(controller.V3Controller, ImpliedRolesV3).base_url( - context, 'public') + request.context_dict, 'public') for ref in refs: implied_role_id = ref['implied_role_id'] @@ -614,7 +619,7 @@ class GrantAssignmentV3(controller.V3Controller): self.check_protection(context, protection, ref) @controller.protected(callback=_check_grant_protection) - def create_grant(self, context, role_id, user_id=None, + def create_grant(self, request, role_id, user_id=None, group_id=None, domain_id=None, project_id=None): """Grant a role to a user or group on either a domain or project.""" self._require_domain_xor_project(domain_id, project_id) @@ -622,10 +627,11 @@ class GrantAssignmentV3(controller.V3Controller): self.assignment_api.create_grant( role_id, user_id, group_id, domain_id, project_id, - self._check_if_inherited(context), context) + self._check_if_inherited(request.context_dict), + request.context_dict) @controller.protected(callback=_check_grant_protection) - def list_grants(self, context, user_id=None, + def list_grants(self, request, user_id=None, group_id=None, domain_id=None, project_id=None): """List roles granted to user/group on either a domain or project.""" self._require_domain_xor_project(domain_id, project_id) @@ -633,11 +639,11 @@ class GrantAssignmentV3(controller.V3Controller): refs = self.assignment_api.list_grants( user_id, group_id, domain_id, project_id, - self._check_if_inherited(context)) - return GrantAssignmentV3.wrap_collection(context, refs) + self._check_if_inherited(request.context_dict)) + return GrantAssignmentV3.wrap_collection(request.context_dict, refs) @controller.protected(callback=_check_grant_protection) - def check_grant(self, context, role_id, user_id=None, + def check_grant(self, request, role_id, user_id=None, group_id=None, domain_id=None, project_id=None): """Check if a role has been granted on either a domain or project.""" self._require_domain_xor_project(domain_id, project_id) @@ -645,14 +651,14 @@ class GrantAssignmentV3(controller.V3Controller): self.assignment_api.get_grant( role_id, user_id, group_id, domain_id, project_id, - self._check_if_inherited(context)) + self._check_if_inherited(request.context_dict)) # NOTE(lbragstad): This will allow users to clean up role assignments # from the backend in the event the user was removed prior to the role # assignment being removed. @controller.protected(callback=functools.partial( _check_grant_protection, allow_no_user=True)) - def revoke_grant(self, context, role_id, user_id=None, + def revoke_grant(self, request, role_id, user_id=None, group_id=None, domain_id=None, project_id=None): """Revoke a role from user/group on either a domain or project.""" self._require_domain_xor_project(domain_id, project_id) @@ -660,7 +666,8 @@ class GrantAssignmentV3(controller.V3Controller): self.assignment_api.delete_grant( role_id, user_id, group_id, domain_id, project_id, - self._check_if_inherited(context), context) + self._check_if_inherited(request.context_dict), + request.context_dict) @dependency.requires('assignment_api', 'identity_api', 'resource_api') @@ -925,8 +932,8 @@ class RoleAssignmentV3(controller.V3Controller): @controller.filterprotected('group.id', 'role.id', 'scope.domain.id', 'scope.project.id', 'scope.OS-INHERIT:inherited_to', 'user.id') - def list_role_assignments(self, context, filters): - return self._list_role_assignments(context, filters) + def list_role_assignments(self, request, filters): + return self._list_role_assignments(request.context_dict, filters) def _check_list_tree_protection(self, context, protection_info): """Check protection for list assignment for tree API. @@ -947,15 +954,15 @@ class RoleAssignmentV3(controller.V3Controller): 'scope.domain.id', 'scope.project.id', 'scope.OS-INHERIT:inherited_to', 'user.id', callback=_check_list_tree_protection) - def list_role_assignments_for_tree(self, context, filters): - if not context['query_string'].get('scope.project.id'): + def list_role_assignments_for_tree(self, request, filters): + if not request.context_dict['query_string'].get('scope.project.id'): msg = _('scope.project.id must be specified if include_subtree ' 'is also specified') raise exception.ValidationError(message=msg) - return self._list_role_assignments(context, filters, + return self._list_role_assignments(request.context_dict, filters, include_subtree=True) - def list_role_assignments_wrapper(self, context): + def list_role_assignments_wrapper(self, request): """Main entry point from router for list role assignments. Since we want different policy file rules to be applicable based on @@ -964,9 +971,9 @@ class RoleAssignmentV3(controller.V3Controller): protected entry point. """ - params = context['query_string'] + params = request.context_dict['query_string'] if 'include_subtree' in params and ( self.query_filter_is_true(params['include_subtree'])): - return self.list_role_assignments_for_tree(context) + return self.list_role_assignments_for_tree(request) else: - return self.list_role_assignments(context) + return self.list_role_assignments(request) diff --git a/keystone/auth/controllers.py b/keystone/auth/controllers.py index ca01afd47..6c324c752 100644 --- a/keystone/auth/controllers.py +++ b/keystone/auth/controllers.py @@ -391,16 +391,17 @@ class Auth(controller.V3Controller): super(Auth, self).__init__(*args, **kw) config.setup_authentication() - def authenticate_for_token(self, context, auth=None): + def authenticate_for_token(self, request, auth=None): """Authenticate user and issue a token.""" - include_catalog = 'nocatalog' not in context['query_string'] + query_string = request.context_dict['query_string'] + include_catalog = 'nocatalog' not in query_string try: - auth_info = AuthInfo.create(context, auth=auth) + auth_info = AuthInfo.create(request.context_dict, auth=auth) auth_context = AuthContext(extras={}, method_names=[], bind={}) - self.authenticate(context, auth_info, auth_context) + self.authenticate(request, auth_info, auth_context) if auth_context.get('access_token_id'): auth_info.set_scope(None, auth_context['project_id'], None) self._check_and_set_default_scoping(auth_info, auth_context) @@ -496,15 +497,17 @@ class Auth(controller.V3Controller): LOG.warning(msg, {'user_id': user_ref['id'], 'project_id': default_project_id}) - def authenticate(self, context, auth_info, auth_context): + def authenticate(self, request, auth_info, auth_context): """Authenticate user.""" # The 'external' method allows any 'REMOTE_USER' based authentication # In some cases the server can set REMOTE_USER as '' instead of # dropping it, so this must be filtered out - if context['environment'].get('REMOTE_USER'): + if request.context_dict['environment'].get('REMOTE_USER'): try: external = get_auth_method('external') - external.authenticate(context, auth_info, auth_context) + external.authenticate(request.context_dict, + auth_info, + auth_context) except exception.AuthMethodNotSupported: # This will happen there is no 'external' plugin registered # and the container is performing authentication. @@ -523,7 +526,7 @@ class Auth(controller.V3Controller): auth_response = {'methods': []} for method_name in auth_info.get_method_names(): method = get_auth_method(method_name) - resp = method.authenticate(context, + resp = method.authenticate(request.context_dict, auth_info.get_method_data(method_name), auth_context) if resp: @@ -539,8 +542,8 @@ class Auth(controller.V3Controller): raise exception.Unauthorized(msg) @controller.protected() - def check_token(self, context): - token_id = context.get('subject_token_id') + def check_token(self, request): + token_id = request.context_dict.get('subject_token_id') token_data = self.token_provider_api.validate_v3_token( token_id) # NOTE(morganfainberg): The code in @@ -549,14 +552,15 @@ class Auth(controller.V3Controller): return render_token_data_response(token_id, token_data) @controller.protected() - def revoke_token(self, context): - token_id = context.get('subject_token_id') + def revoke_token(self, request): + token_id = request.context_dict.get('subject_token_id') return self.token_provider_api.revoke_token(token_id) @controller.protected() - def validate_token(self, context): - token_id = context.get('subject_token_id') - include_catalog = 'nocatalog' not in context['query_string'] + def validate_token(self, request): + token_id = request.context_dict.get('subject_token_id') + query_string = request.context_dict['query_string'] + include_catalog = 'nocatalog' not in query_string token_data = self.token_provider_api.validate_v3_token( token_id) if not include_catalog and 'catalog' in token_data['token']: @@ -564,11 +568,12 @@ class Auth(controller.V3Controller): return render_token_data_response(token_id, token_data) @controller.protected() - def revocation_list(self, context, auth=None): + def revocation_list(self, request, auth=None): if not CONF.token.revoke_by_id: raise exception.Gone() - audit_id_only = ('audit_id_only' in context['query_string']) + query_string = request.context_dict['query_string'] + audit_id_only = 'audit_id_only' in query_string tokens = self.token_provider_api.list_revoked_tokens() @@ -600,8 +605,8 @@ class Auth(controller.V3Controller): return a or b @controller.protected() - def get_auth_projects(self, context): - auth_context = self.get_auth_context(context) + def get_auth_projects(self, request): + auth_context = self.get_auth_context(request.context_dict) user_id = auth_context.get('user_id') user_refs = [] @@ -618,11 +623,12 @@ class Auth(controller.V3Controller): grp_refs = self.assignment_api.list_projects_for_groups(group_ids) refs = self._combine_lists_uniquely(user_refs, grp_refs) - return resource_controllers.ProjectV3.wrap_collection(context, refs) + return resource_controllers.ProjectV3.wrap_collection( + request.context_dict, refs) @controller.protected() - def get_auth_domains(self, context): - auth_context = self.get_auth_context(context) + def get_auth_domains(self, request): + auth_context = self.get_auth_context(request.context_dict) user_id = auth_context.get('user_id') user_refs = [] @@ -639,11 +645,12 @@ class Auth(controller.V3Controller): grp_refs = self.assignment_api.list_domains_for_groups(group_ids) refs = self._combine_lists_uniquely(user_refs, grp_refs) - return resource_controllers.DomainV3.wrap_collection(context, refs) + return resource_controllers.DomainV3.wrap_collection( + request.context_dict, refs) @controller.protected() - def get_auth_catalog(self, context): - auth_context = self.get_auth_context(context) + def get_auth_catalog(self, request): + auth_context = self.get_auth_context(request.context_dict) user_id = auth_context.get('user_id') project_id = auth_context.get('project_id') @@ -660,7 +667,8 @@ class Auth(controller.V3Controller): # several private methods. return { 'catalog': self.catalog_api.get_v3_catalog(user_id, project_id), - 'links': {'self': self.base_url(context, path='auth/catalog')} + 'links': {'self': self.base_url(request.context_dict, + path='auth/catalog')} } diff --git a/keystone/catalog/controllers.py b/keystone/catalog/controllers.py index aa7a6aa48..e4da5b6d6 100644 --- a/keystone/catalog/controllers.py +++ b/keystone/catalog/controllers.py @@ -36,30 +36,30 @@ INTERFACES = ['public', 'internal', 'admin'] class Service(controller.V2Controller): @controller.v2_deprecated - def get_services(self, context): - self.assert_admin(context) + def get_services(self, request): + self.assert_admin(request.context_dict) service_list = self.catalog_api.list_services() return {'OS-KSADM:services': service_list} @controller.v2_deprecated - def get_service(self, context, service_id): - self.assert_admin(context) + def get_service(self, request, service_id): + self.assert_admin(request.context_dict) service_ref = self.catalog_api.get_service(service_id) return {'OS-KSADM:service': service_ref} @controller.v2_deprecated - def delete_service(self, context, service_id): - self.assert_admin(context) - initiator = notifications._get_request_audit_info(context) + def delete_service(self, request, service_id): + self.assert_admin(request.context_dict) + initiator = notifications._get_request_audit_info(request.context_dict) self.catalog_api.delete_service(service_id, initiator) @controller.v2_deprecated - def create_service(self, context, OS_KSADM_service): - self.assert_admin(context) + def create_service(self, request, OS_KSADM_service): + self.assert_admin(request.context_dict) service_id = uuid.uuid4().hex service_ref = OS_KSADM_service.copy() service_ref['id'] = service_id - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) new_service_ref = self.catalog_api.create_service( service_id, service_ref, initiator) return {'OS-KSADM:service': new_service_ref} @@ -69,9 +69,9 @@ class Service(controller.V2Controller): class Endpoint(controller.V2Controller): @controller.v2_deprecated - def get_endpoints(self, context): + def get_endpoints(self, request): """Merge matching v3 endpoint refs into legacy refs.""" - self.assert_admin(context) + self.assert_admin(request.context_dict) legacy_endpoints = {} v3_endpoints = {} for endpoint in self.catalog_api.list_endpoints(): @@ -129,9 +129,9 @@ class Endpoint(controller.V2Controller): return {'endpoints': list(legacy_endpoints.values())} @controller.v2_deprecated - def create_endpoint(self, context, endpoint): + def create_endpoint(self, request, endpoint): """Create three v3 endpoint refs based on a legacy ref.""" - self.assert_admin(context) + self.assert_admin(request.context_dict) # according to the v2 spec publicurl is mandatory self._require_attribute(endpoint, 'publicurl') @@ -146,7 +146,7 @@ class Endpoint(controller.V2Controller): if interface_url: utils.check_endpoint_url(interface_url) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) if endpoint.get('region') is not None: try: @@ -184,10 +184,10 @@ class Endpoint(controller.V2Controller): return {'endpoint': legacy_endpoint_ref} @controller.v2_deprecated - def delete_endpoint(self, context, endpoint_id): + def delete_endpoint(self, request, endpoint_id): """Delete up to three v3 endpoint refs based on a legacy ref ID.""" - self.assert_admin(context) - initiator = notifications._get_request_audit_info(context) + self.assert_admin(request.context_dict) + initiator = notifications._get_request_audit_info(request.context_dict) deleted_at_least_one = False for endpoint in self.catalog_api.list_endpoints(): @@ -221,40 +221,42 @@ class RegionV3(controller.V3Controller): @controller.protected() @validation.validated(schema.region_create, 'region') - def create_region(self, context, region): + def create_region(self, request, region): ref = self._normalize_dict(region) if not ref.get('id'): ref = self._assign_unique_id(ref) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.catalog_api.create_region(ref, initiator) return wsgi.render_response( - RegionV3.wrap_member(context, ref), + RegionV3.wrap_member(request.context_dict, ref), status=(201, 'Created')) @controller.filterprotected('parent_region_id') - def list_regions(self, context, filters): - hints = RegionV3.build_driver_hints(context, filters) + def list_regions(self, request, filters): + hints = RegionV3.build_driver_hints(request.context_dict, filters) refs = self.catalog_api.list_regions(hints) - return RegionV3.wrap_collection(context, refs, hints=hints) + return RegionV3.wrap_collection(request.context_dict, + refs, + hints=hints) @controller.protected() - def get_region(self, context, region_id): + def get_region(self, request, region_id): ref = self.catalog_api.get_region(region_id) - return RegionV3.wrap_member(context, ref) + return RegionV3.wrap_member(request.context_dict, ref) @controller.protected() @validation.validated(schema.region_update, 'region') - def update_region(self, context, region_id, region): + def update_region(self, request, region_id, region): self._require_matching_id(region_id, region) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.catalog_api.update_region(region_id, region, initiator) - return RegionV3.wrap_member(context, ref) + return RegionV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_region(self, context, region_id): - initiator = notifications._get_request_audit_info(context) + def delete_region(self, request, region_id): + initiator = notifications._get_request_audit_info(request.context_dict) return self.catalog_api.delete_region(region_id, initiator) @@ -269,34 +271,36 @@ class ServiceV3(controller.V3Controller): @controller.protected() @validation.validated(schema.service_create, 'service') - def create_service(self, context, service): + def create_service(self, request, service): ref = self._assign_unique_id(self._normalize_dict(service)) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.catalog_api.create_service(ref['id'], ref, initiator) - return ServiceV3.wrap_member(context, ref) + return ServiceV3.wrap_member(request.context_dict, ref) @controller.filterprotected('type', 'name') - def list_services(self, context, filters): - hints = ServiceV3.build_driver_hints(context, filters) + def list_services(self, request, filters): + hints = ServiceV3.build_driver_hints(request.context_dict, filters) refs = self.catalog_api.list_services(hints=hints) - return ServiceV3.wrap_collection(context, refs, hints=hints) + return ServiceV3.wrap_collection(request.context_dict, + refs, + hints=hints) @controller.protected() - def get_service(self, context, service_id): + def get_service(self, request, service_id): ref = self.catalog_api.get_service(service_id) - return ServiceV3.wrap_member(context, ref) + return ServiceV3.wrap_member(request.context_dict, ref) @controller.protected() @validation.validated(schema.service_update, 'service') - def update_service(self, context, service_id, service): + def update_service(self, request, service_id, service): self._require_matching_id(service_id, service) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.catalog_api.update_service(service_id, service, initiator) - return ServiceV3.wrap_member(context, ref) + return ServiceV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_service(self, context, service_id): - initiator = notifications._get_request_audit_info(context) + def delete_service(self, request, service_id): + initiator = notifications._get_request_audit_info(request.context_dict) return self.catalog_api.delete_service(service_id, initiator) @@ -347,40 +351,43 @@ class EndpointV3(controller.V3Controller): @controller.protected() @validation.validated(schema.endpoint_create, 'endpoint') - def create_endpoint(self, context, endpoint): + def create_endpoint(self, request, endpoint): utils.check_endpoint_url(endpoint['url']) ref = self._assign_unique_id(self._normalize_dict(endpoint)) - ref = self._validate_endpoint_region(ref, context) - initiator = notifications._get_request_audit_info(context) + ref = self._validate_endpoint_region(ref, request.context_dict) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.catalog_api.create_endpoint(ref['id'], ref, initiator) - return EndpointV3.wrap_member(context, ref) + return EndpointV3.wrap_member(request.context_dict, ref) @controller.filterprotected('interface', 'service_id', 'region_id') - def list_endpoints(self, context, filters): - hints = EndpointV3.build_driver_hints(context, filters) + def list_endpoints(self, request, filters): + hints = EndpointV3.build_driver_hints(request.context_dict, filters) refs = self.catalog_api.list_endpoints(hints=hints) - return EndpointV3.wrap_collection(context, refs, hints=hints) + return EndpointV3.wrap_collection(request.context_dict, + refs, + hints=hints) @controller.protected() - def get_endpoint(self, context, endpoint_id): + def get_endpoint(self, request, endpoint_id): ref = self.catalog_api.get_endpoint(endpoint_id) - return EndpointV3.wrap_member(context, ref) + return EndpointV3.wrap_member(request.context_dict, ref) @controller.protected() @validation.validated(schema.endpoint_update, 'endpoint') - def update_endpoint(self, context, endpoint_id, endpoint): + def update_endpoint(self, request, endpoint_id, endpoint): self._require_matching_id(endpoint_id, endpoint) - endpoint = self._validate_endpoint_region(endpoint.copy(), context) + endpoint = self._validate_endpoint_region(endpoint.copy(), + request.context_dict) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.catalog_api.update_endpoint(endpoint_id, endpoint, initiator) - return EndpointV3.wrap_member(context, ref) + return EndpointV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_endpoint(self, context, endpoint_id): - initiator = notifications._get_request_audit_info(context) + def delete_endpoint(self, request, endpoint_id): + initiator = notifications._get_request_audit_info(request.context_dict) return self.catalog_api.delete_endpoint(endpoint_id, initiator) @@ -407,7 +414,7 @@ class EndpointFilterV3Controller(controller.V3Controller): project_or_endpoint_id) @controller.protected() - def add_endpoint_to_project(self, context, project_id, endpoint_id): + def add_endpoint_to_project(self, request, project_id, endpoint_id): """Establish an association between an endpoint and a project.""" # NOTE(gyee): we just need to make sure endpoint and project exist # first. We don't really care whether if project is disabled. @@ -419,7 +426,7 @@ class EndpointFilterV3Controller(controller.V3Controller): project_id) @controller.protected() - def check_endpoint_in_project(self, context, project_id, endpoint_id): + def check_endpoint_in_project(self, request, project_id, endpoint_id): """Verify endpoint is currently associated with given project.""" self.catalog_api.get_endpoint(endpoint_id) self.resource_api.get_project(project_id) @@ -427,31 +434,32 @@ class EndpointFilterV3Controller(controller.V3Controller): project_id) @controller.protected() - def list_endpoints_for_project(self, context, project_id): + def list_endpoints_for_project(self, request, project_id): """List all endpoints currently associated with a given project.""" self.resource_api.get_project(project_id) filtered_endpoints = self.catalog_api.list_endpoints_for_project( project_id) return EndpointV3.wrap_collection( - context, [v for v in six.itervalues(filtered_endpoints)]) + request.context_dict, + [v for v in six.itervalues(filtered_endpoints)]) @controller.protected() - def remove_endpoint_from_project(self, context, project_id, endpoint_id): + def remove_endpoint_from_project(self, request, project_id, endpoint_id): """Remove the endpoint from the association with given project.""" self.catalog_api.remove_endpoint_from_project(endpoint_id, project_id) @controller.protected() - def list_projects_for_endpoint(self, context, endpoint_id): + def list_projects_for_endpoint(self, request, endpoint_id): """Return a list of projects associated with the endpoint.""" self.catalog_api.get_endpoint(endpoint_id) refs = self.catalog_api.list_projects_for_endpoint(endpoint_id) projects = [self.resource_api.get_project( ref['project_id']) for ref in refs] - return resource.controllers.ProjectV3.wrap_collection(context, - projects) + return resource.controllers.ProjectV3.wrap_collection( + request.context_dict, projects) @dependency.requires('catalog_api', 'resource_api') @@ -473,13 +481,13 @@ class EndpointGroupV3Controller(controller.V3Controller): @controller.protected() @validation.validated(schema.endpoint_group_create, 'endpoint_group') - def create_endpoint_group(self, context, endpoint_group): + def create_endpoint_group(self, request, endpoint_group): """Create an Endpoint Group with the associated filters.""" ref = self._assign_unique_id(self._normalize_dict(endpoint_group)) self._require_attribute(ref, 'filters') self._require_valid_filter(ref) ref = self.catalog_api.create_endpoint_group(ref['id'], ref) - return EndpointGroupV3Controller.wrap_member(context, ref) + return EndpointGroupV3Controller.wrap_member(request.context_dict, ref) def _require_valid_filter(self, endpoint_group): filters = endpoint_group.get('filters') @@ -493,15 +501,15 @@ class EndpointGroupV3Controller(controller.V3Controller): return ' or '.join(self.VALID_FILTER_KEYS) @controller.protected() - def get_endpoint_group(self, context, endpoint_group_id): + def get_endpoint_group(self, request, endpoint_group_id): """Retrieve the endpoint group associated with the id if exists.""" ref = self.catalog_api.get_endpoint_group(endpoint_group_id) return EndpointGroupV3Controller.wrap_member( - context, ref) + request.context_dict, ref) @controller.protected() @validation.validated(schema.endpoint_group_update, 'endpoint_group') - def update_endpoint_group(self, context, endpoint_group_id, + def update_endpoint_group(self, request, endpoint_group_id, endpoint_group): """Update fixed values and/or extend the filters.""" if 'filters' in endpoint_group: @@ -509,30 +517,30 @@ class EndpointGroupV3Controller(controller.V3Controller): ref = self.catalog_api.update_endpoint_group(endpoint_group_id, endpoint_group) return EndpointGroupV3Controller.wrap_member( - context, ref) + request.context_dict, ref) @controller.protected() - def delete_endpoint_group(self, context, endpoint_group_id): + def delete_endpoint_group(self, request, endpoint_group_id): """Delete endpoint_group.""" self.catalog_api.delete_endpoint_group(endpoint_group_id) @controller.protected() - def list_endpoint_groups(self, context): + def list_endpoint_groups(self, request): """List all endpoint groups.""" refs = self.catalog_api.list_endpoint_groups() return EndpointGroupV3Controller.wrap_collection( - context, refs) + request.context_dict, refs) @controller.protected() - def list_endpoint_groups_for_project(self, context, project_id): + def list_endpoint_groups_for_project(self, request, project_id): """List all endpoint groups associated with a given project.""" return EndpointGroupV3Controller.wrap_collection( - context, + request.context_dict, self.catalog_api.get_endpoint_groups_for_project(project_id)) @controller.protected() def list_projects_associated_with_endpoint_group(self, - context, + request, endpoint_group_id): """List all projects associated with endpoint group.""" endpoint_group_refs = (self.catalog_api. @@ -544,18 +552,19 @@ class EndpointGroupV3Controller(controller.V3Controller): endpoint_group_ref['project_id']) if project: projects.append(project) - return resource.controllers.ProjectV3.wrap_collection(context, - projects) + return resource.controllers.ProjectV3.wrap_collection( + request.context_dict, projects) @controller.protected() def list_endpoints_associated_with_endpoint_group(self, - context, + request, endpoint_group_id): """List all the endpoints filtered by a specific endpoint group.""" filtered_endpoints = (self.catalog_api. get_endpoints_filtered_by_endpoint_group( endpoint_group_id)) - return EndpointV3.wrap_collection(context, filtered_endpoints) + return EndpointV3.wrap_collection(request.context_dict, + filtered_endpoints) @dependency.requires('catalog_api', 'resource_api') @@ -577,7 +586,7 @@ class ProjectEndpointGroupV3Controller(controller.V3Controller): project_id)) @controller.protected() - def get_endpoint_group_in_project(self, context, endpoint_group_id, + def get_endpoint_group_in_project(self, request, endpoint_group_id, project_id): """Retrieve the endpoint group associated with the id if exists.""" self.resource_api.get_project(project_id) @@ -585,10 +594,10 @@ class ProjectEndpointGroupV3Controller(controller.V3Controller): ref = self.catalog_api.get_endpoint_group_in_project( endpoint_group_id, project_id) return ProjectEndpointGroupV3Controller.wrap_member( - context, ref) + request.context_dict, ref) @controller.protected() - def add_endpoint_group_to_project(self, context, endpoint_group_id, + def add_endpoint_group_to_project(self, request, endpoint_group_id, project_id): """Create an association between an endpoint group and project.""" self.resource_api.get_project(project_id) @@ -597,7 +606,7 @@ class ProjectEndpointGroupV3Controller(controller.V3Controller): endpoint_group_id, project_id) @controller.protected() - def remove_endpoint_group_from_project(self, context, endpoint_group_id, + def remove_endpoint_group_from_project(self, request, endpoint_group_id, project_id): """Remove the endpoint group from associated project.""" self.resource_api.get_project(project_id) diff --git a/keystone/common/controller.py b/keystone/common/controller.py index b9e4038bb..30ab75cff 100644 --- a/keystone/common/controller.py +++ b/keystone/common/controller.py @@ -122,17 +122,23 @@ def protected(callback=None): """ def wrapper(f): @functools.wraps(f) - def inner(self, context, *args, **kwargs): - if 'is_admin' in context and context['is_admin']: + def inner(self, request, *args, **kwargs): + if request.context_dict.get('is_admin', False): LOG.warning(_LW('RBAC: Bypassing authorization')) elif callback is not None: prep_info = {'f_name': f.__name__, 'input_attr': kwargs} - callback(self, context, prep_info, *args, **kwargs) + callback(self, + request.context_dict, + prep_info, + *args, + **kwargs) else: action = 'identity:%s' % f.__name__ - creds = _build_policy_check_credentials(self, action, - context, kwargs) + creds = _build_policy_check_credentials(self, + action, + request.context_dict, + kwargs) policy_dict = {} @@ -149,11 +155,11 @@ def protected(callback=None): # TODO(henry-nash): Move this entire code to a member # method inside v3 Auth - if context.get('subject_token_id') is not None: + if request.context_dict.get('subject_token_id') is not None: token_ref = token_model.KeystoneToken( - token_id=context['subject_token_id'], + token_id=request.context_dict['subject_token_id'], token_data=self.token_provider_api.validate_token( - context['subject_token_id'])) + request.context_dict['subject_token_id'])) policy_dict.setdefault('target', {}) policy_dict['target'].setdefault(self.member_name, {}) policy_dict['target'][self.member_name]['user_id'] = ( @@ -178,7 +184,7 @@ def protected(callback=None): action, utils.flatten_dict(policy_dict)) LOG.debug('RBAC: Authorization granted') - return f(self, context, *args, **kwargs) + return f(self, request, *args, **kwargs) return inner return wrapper @@ -198,8 +204,8 @@ def filterprotected(*filters, **callback): """ def _filterprotected(f): @functools.wraps(f) - def wrapper(self, context, **kwargs): - if not context['is_admin']: + def wrapper(self, request, **kwargs): + if not request.context_dict['is_admin']: # The target dict for the policy check will include: # # - Any query filter parameters @@ -212,8 +218,9 @@ def filterprotected(*filters, **callback): target = dict() if filters: for item in filters: - if item in context['query_string']: - target[item] = context['query_string'][item] + if item in request.context_dict['query_string']: + i = request.context_dict['query_string'][item] + target[item] = i LOG.debug('RBAC: Adding query filter params (%s)', ( ', '.join(['%s=%s' % (item, target[item]) @@ -227,12 +234,15 @@ def filterprotected(*filters, **callback): prep_info = {'f_name': f.__name__, 'input_attr': kwargs, 'filter_attr': target} - callback['callback'](self, context, prep_info, **kwargs) + callback['callback'](self, + request.context_dict, + prep_info, + **kwargs) else: # No callback, so we are going to check the protection here action = 'identity:%s' % f.__name__ - creds = _build_policy_check_credentials(self, action, - context, kwargs) + creds = _build_policy_check_credentials( + self, action, request.context_dict, kwargs) # Add in any formal url parameters for key in kwargs: target[key] = kwargs[key] @@ -244,7 +254,7 @@ def filterprotected(*filters, **callback): LOG.debug('RBAC: Authorization granted') else: LOG.warning(_LW('RBAC: Bypassing authorization')) - return f(self, context, filters, **kwargs) + return f(self, request, filters, **kwargs) return wrapper return _filterprotected diff --git a/keystone/common/wsgi.py b/keystone/common/wsgi.py index 4071bca62..933d1b3f7 100644 --- a/keystone/common/wsgi.py +++ b/keystone/common/wsgi.py @@ -219,7 +219,7 @@ class Application(BaseApplication): params = self._normalize_dict(params) try: - result = method(req.context_dict, **params) + result = method(req, **params) except exception.Unauthorized as e: LOG.warning( _LW("Authorization failed. %(exception)s from " diff --git a/keystone/contrib/ec2/controllers.py b/keystone/contrib/ec2/controllers.py index c0f6067e0..966936d62 100644 --- a/keystone/contrib/ec2/controllers.py +++ b/keystone/contrib/ec2/controllers.py @@ -265,7 +265,7 @@ class Ec2ControllerCommon(object): class Ec2Controller(Ec2ControllerCommon, controller.V2Controller): @controller.v2_ec2_deprecated - def authenticate(self, context, credentials=None, ec2Credentials=None): + def authenticate(self, request, credentials=None, ec2Credentials=None): (user_ref, tenant_ref, metadata_ref, roles_ref, catalog_ref) = self._authenticate(credentials=credentials, ec2credentials=ec2Credentials) @@ -285,29 +285,29 @@ class Ec2Controller(Ec2ControllerCommon, controller.V2Controller): return token_data @controller.v2_ec2_deprecated - def get_credential(self, context, user_id, credential_id): - if not self._is_admin(context): - self._assert_identity(context, user_id) + def get_credential(self, request, user_id, credential_id): + if not self._is_admin(request.context_dict): + self._assert_identity(request.context_dict, user_id) return super(Ec2Controller, self).get_credential(user_id, credential_id) @controller.v2_ec2_deprecated - def get_credentials(self, context, user_id): - if not self._is_admin(context): - self._assert_identity(context, user_id) + def get_credentials(self, request, user_id): + if not self._is_admin(request.context_dict): + self._assert_identity(request.context_dict, user_id) return super(Ec2Controller, self).get_credentials(user_id) @controller.v2_ec2_deprecated - def create_credential(self, context, user_id, tenant_id): - if not self._is_admin(context): - self._assert_identity(context, user_id) - return super(Ec2Controller, self).create_credential(context, user_id, - tenant_id) + def create_credential(self, request, user_id, tenant_id): + if not self._is_admin(request.context_dict): + self._assert_identity(request.context_dict, user_id) + return super(Ec2Controller, self).create_credential( + request.context_dict, user_id, tenant_id) @controller.v2_ec2_deprecated - def delete_credential(self, context, user_id, credential_id): - if not self._is_admin(context): - self._assert_identity(context, user_id) + def delete_credential(self, request, user_id, credential_id): + if not self._is_admin(request.context_dict): + self._assert_identity(request.context_dict, user_id) self._assert_owner(user_id, credential_id) return super(Ec2Controller, self).delete_credential(user_id, credential_id) @@ -392,24 +392,27 @@ class Ec2ControllerV3(Ec2ControllerCommon, controller.V3Controller): return render_token_data_response(token_id, token_data) @controller.protected(callback=_check_credential_owner_and_user_id_match) - def ec2_get_credential(self, context, user_id, credential_id): + def ec2_get_credential(self, request, user_id, credential_id): ref = super(Ec2ControllerV3, self).get_credential(user_id, credential_id) - return Ec2ControllerV3.wrap_member(context, ref['credential']) + return Ec2ControllerV3.wrap_member(request.context_dict, + ref['credential']) @controller.protected() - def ec2_list_credentials(self, context, user_id): + def ec2_list_credentials(self, request, user_id): refs = super(Ec2ControllerV3, self).get_credentials(user_id) - return Ec2ControllerV3.wrap_collection(context, refs['credentials']) + return Ec2ControllerV3.wrap_collection(request.context_dict, + refs['credentials']) @controller.protected() - def ec2_create_credential(self, context, user_id, tenant_id): - ref = super(Ec2ControllerV3, self).create_credential(context, user_id, - tenant_id) - return Ec2ControllerV3.wrap_member(context, ref['credential']) + def ec2_create_credential(self, request, user_id, tenant_id): + ref = super(Ec2ControllerV3, self).create_credential( + request.context_dict, user_id, tenant_id) + return Ec2ControllerV3.wrap_member(request.context_dict, + ref['credential']) @controller.protected(callback=_check_credential_owner_and_user_id_match) - def ec2_delete_credential(self, context, user_id, credential_id): + def ec2_delete_credential(self, request, user_id, credential_id): return super(Ec2ControllerV3, self).delete_credential(user_id, credential_id) diff --git a/keystone/credential/controllers.py b/keystone/credential/controllers.py index 3ce16d92d..85dd60221 100644 --- a/keystone/credential/controllers.py +++ b/keystone/credential/controllers.py @@ -63,12 +63,12 @@ class CredentialV3(controller.V3Controller): @controller.protected() @validation.validated(schema.credential_create, 'credential') - def create_credential(self, context, credential): - trust_id = self._get_trust_id_for_request(context) + def create_credential(self, request, credential): + trust_id = self._get_trust_id_for_request(request.context_dict) ref = self._assign_unique_id(self._normalize_dict(credential), trust_id) ref = self.credential_api.create_credential(ref['id'], ref) - return CredentialV3.wrap_member(context, ref) + return CredentialV3.wrap_member(request.context_dict, ref) @staticmethod def _blob_to_json(ref): @@ -83,27 +83,27 @@ class CredentialV3(controller.V3Controller): return ref @controller.filterprotected('user_id', 'type') - def list_credentials(self, context, filters): - hints = CredentialV3.build_driver_hints(context, filters) + def list_credentials(self, request, filters): + hints = CredentialV3.build_driver_hints(request.context_dict, filters) refs = self.credential_api.list_credentials(hints) ret_refs = [self._blob_to_json(r) for r in refs] - return CredentialV3.wrap_collection(context, ret_refs, + return CredentialV3.wrap_collection(request.context_dict, ret_refs, hints=hints) @controller.protected() - def get_credential(self, context, credential_id): + def get_credential(self, request, credential_id): ref = self.credential_api.get_credential(credential_id) ret_ref = self._blob_to_json(ref) - return CredentialV3.wrap_member(context, ret_ref) + return CredentialV3.wrap_member(request.context_dict, ret_ref) @controller.protected() @validation.validated(schema.credential_update, 'credential') - def update_credential(self, context, credential_id, credential): + def update_credential(self, request, credential_id, credential): self._require_matching_id(credential_id, credential) ref = self.credential_api.update_credential(credential_id, credential) - return CredentialV3.wrap_member(context, ref) + return CredentialV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_credential(self, context, credential_id): + def delete_credential(self, request, credential_id): return self.credential_api.delete_credential(credential_id) diff --git a/keystone/endpoint_policy/controllers.py b/keystone/endpoint_policy/controllers.py index b96834dca..02dfbcf1c 100644 --- a/keystone/endpoint_policy/controllers.py +++ b/keystone/endpoint_policy/controllers.py @@ -50,7 +50,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): payload['resource_info']) @controller.protected() - def create_policy_association_for_endpoint(self, context, + def create_policy_association_for_endpoint(self, request, policy_id, endpoint_id): """Create an association between a policy and an endpoint.""" self.policy_api.get_policy(policy_id) @@ -59,7 +59,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): policy_id, endpoint_id=endpoint_id) @controller.protected() - def check_policy_association_for_endpoint(self, context, + def check_policy_association_for_endpoint(self, request, policy_id, endpoint_id): """Check an association between a policy and an endpoint.""" self.policy_api.get_policy(policy_id) @@ -68,7 +68,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): policy_id, endpoint_id=endpoint_id) @controller.protected() - def delete_policy_association_for_endpoint(self, context, + def delete_policy_association_for_endpoint(self, request, policy_id, endpoint_id): """Delete an association between a policy and an endpoint.""" self.policy_api.get_policy(policy_id) @@ -77,7 +77,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): policy_id, endpoint_id=endpoint_id) @controller.protected() - def create_policy_association_for_service(self, context, + def create_policy_association_for_service(self, request, policy_id, service_id): """Create an association between a policy and a service.""" self.policy_api.get_policy(policy_id) @@ -86,7 +86,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): policy_id, service_id=service_id) @controller.protected() - def check_policy_association_for_service(self, context, + def check_policy_association_for_service(self, request, policy_id, service_id): """Check an association between a policy and a service.""" self.policy_api.get_policy(policy_id) @@ -95,7 +95,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): policy_id, service_id=service_id) @controller.protected() - def delete_policy_association_for_service(self, context, + def delete_policy_association_for_service(self, request, policy_id, service_id): """Delete an association between a policy and a service.""" self.policy_api.get_policy(policy_id) @@ -105,7 +105,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): @controller.protected() def create_policy_association_for_region_and_service( - self, context, policy_id, service_id, region_id): + self, request, policy_id, service_id, region_id): """Create an association between a policy and region+service.""" self.policy_api.get_policy(policy_id) self.catalog_api.get_service(service_id) @@ -115,7 +115,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): @controller.protected() def check_policy_association_for_region_and_service( - self, context, policy_id, service_id, region_id): + self, request, policy_id, service_id, region_id): """Check an association between a policy and region+service.""" self.policy_api.get_policy(policy_id) self.catalog_api.get_service(service_id) @@ -125,7 +125,7 @@ class EndpointPolicyV3Controller(controller.V3Controller): @controller.protected() def delete_policy_association_for_region_and_service( - self, context, policy_id, service_id, region_id): + self, request, policy_id, service_id, region_id): """Delete an association between a policy and region+service.""" self.policy_api.get_policy(policy_id) self.catalog_api.get_service(service_id) @@ -134,14 +134,14 @@ class EndpointPolicyV3Controller(controller.V3Controller): policy_id, service_id=service_id, region_id=region_id) @controller.protected() - def get_policy_for_endpoint(self, context, endpoint_id): + def get_policy_for_endpoint(self, request, endpoint_id): """Get the effective policy for an endpoint.""" self.catalog_api.get_endpoint(endpoint_id) ref = self.endpoint_policy_api.get_policy_for_endpoint(endpoint_id) # NOTE(henry-nash): since the collection and member for this class is # set to endpoints, we have to handle wrapping this policy entity # ourselves. - self._add_self_referential_link(context, ref) + self._add_self_referential_link(request.context_dict, ref) return {'policy': ref} # NOTE(henry-nash): As in the catalog controller, we must ensure that the @@ -159,8 +159,9 @@ class EndpointPolicyV3Controller(controller.V3Controller): return super(EndpointPolicyV3Controller, cls).wrap_member(context, ref) @controller.protected() - def list_endpoints_for_policy(self, context, policy_id): + def list_endpoints_for_policy(self, request, policy_id): """List endpoints with the effective association to a policy.""" self.policy_api.get_policy(policy_id) refs = self.endpoint_policy_api.list_endpoints_for_policy(policy_id) - return EndpointPolicyV3Controller.wrap_collection(context, refs) + return EndpointPolicyV3Controller.wrap_collection(request.context_dict, + refs) diff --git a/keystone/federation/controllers.py b/keystone/federation/controllers.py index 6dc9de1ed..6ad0e3c98 100644 --- a/keystone/federation/controllers.py +++ b/keystone/federation/controllers.py @@ -91,35 +91,36 @@ class IdentityProvider(_ControllerBase): @controller.protected() @validation.validated(schema.identity_provider_create, 'identity_provider') - def create_identity_provider(self, context, idp_id, identity_provider): + def create_identity_provider(self, request, idp_id, identity_provider): identity_provider = self._normalize_dict(identity_provider) identity_provider.setdefault('enabled', False) idp_ref = self.federation_api.create_idp(idp_id, identity_provider) - response = IdentityProvider.wrap_member(context, idp_ref) + response = IdentityProvider.wrap_member(request.context_dict, idp_ref) return wsgi.render_response(body=response, status=('201', 'Created')) @controller.filterprotected('id', 'enabled') - def list_identity_providers(self, context, filters): - hints = self.build_driver_hints(context, filters) + def list_identity_providers(self, request, filters): + hints = self.build_driver_hints(request.context_dict, filters) ref = self.federation_api.list_idps(hints=hints) ref = [self.filter_params(x) for x in ref] - return IdentityProvider.wrap_collection(context, ref, hints=hints) + return IdentityProvider.wrap_collection(request.context_dict, + ref, hints=hints) @controller.protected() - def get_identity_provider(self, context, idp_id): + def get_identity_provider(self, request, idp_id): ref = self.federation_api.get_idp(idp_id) - return IdentityProvider.wrap_member(context, ref) + return IdentityProvider.wrap_member(request.context_dict, ref) @controller.protected() - def delete_identity_provider(self, context, idp_id): + def delete_identity_provider(self, request, idp_id): self.federation_api.delete_idp(idp_id) @controller.protected() @validation.validated(schema.identity_provider_update, 'identity_provider') - def update_identity_provider(self, context, idp_id, identity_provider): + def update_identity_provider(self, request, idp_id, identity_provider): identity_provider = self._normalize_dict(identity_provider) idp_ref = self.federation_api.update_idp(idp_id, identity_provider) - return IdentityProvider.wrap_member(context, idp_ref) + return IdentityProvider.wrap_member(request.context_dict, idp_ref) @dependency.requires('federation_api') @@ -179,33 +180,34 @@ class FederationProtocol(_ControllerBase): @controller.protected() @validation.validated(schema.protocol_create, 'protocol') - def create_protocol(self, context, idp_id, protocol_id, protocol): + def create_protocol(self, request, idp_id, protocol_id, protocol): ref = self._normalize_dict(protocol) ref = self.federation_api.create_protocol(idp_id, protocol_id, ref) - response = FederationProtocol.wrap_member(context, ref) + response = FederationProtocol.wrap_member(request.context_dict, ref) return wsgi.render_response(body=response, status=('201', 'Created')) @controller.protected() @validation.validated(schema.protocol_update, 'protocol') - def update_protocol(self, context, idp_id, protocol_id, protocol): + def update_protocol(self, request, idp_id, protocol_id, protocol): ref = self._normalize_dict(protocol) ref = self.federation_api.update_protocol(idp_id, protocol_id, protocol) - return FederationProtocol.wrap_member(context, ref) + return FederationProtocol.wrap_member(request.context_dict, ref) @controller.protected() - def get_protocol(self, context, idp_id, protocol_id): + def get_protocol(self, request, idp_id, protocol_id): ref = self.federation_api.get_protocol(idp_id, protocol_id) - return FederationProtocol.wrap_member(context, ref) + return FederationProtocol.wrap_member(request.context_dict, ref) @controller.protected() - def list_protocols(self, context, idp_id): + def list_protocols(self, request, idp_id): protocols_ref = self.federation_api.list_protocols(idp_id) protocols = list(protocols_ref) - return FederationProtocol.wrap_collection(context, protocols) + return FederationProtocol.wrap_collection(request.context_dict, + protocols) @controller.protected() - def delete_protocol(self, context, idp_id, protocol_id): + def delete_protocol(self, request, idp_id, protocol_id): self.federation_api.delete_protocol(idp_id, protocol_id) @@ -215,33 +217,34 @@ class MappingController(_ControllerBase): member_name = 'mapping' @controller.protected() - def create_mapping(self, context, mapping_id, mapping): + def create_mapping(self, request, mapping_id, mapping): ref = self._normalize_dict(mapping) utils.validate_mapping_structure(ref) mapping_ref = self.federation_api.create_mapping(mapping_id, ref) - response = MappingController.wrap_member(context, mapping_ref) + response = MappingController.wrap_member(request.context_dict, + mapping_ref) return wsgi.render_response(body=response, status=('201', 'Created')) @controller.protected() - def list_mappings(self, context): + def list_mappings(self, request): ref = self.federation_api.list_mappings() - return MappingController.wrap_collection(context, ref) + return MappingController.wrap_collection(request.context_dict, ref) @controller.protected() - def get_mapping(self, context, mapping_id): + def get_mapping(self, request, mapping_id): ref = self.federation_api.get_mapping(mapping_id) - return MappingController.wrap_member(context, ref) + return MappingController.wrap_member(request.context_dict, ref) @controller.protected() - def delete_mapping(self, context, mapping_id): + def delete_mapping(self, request, mapping_id): self.federation_api.delete_mapping(mapping_id) @controller.protected() - def update_mapping(self, context, mapping_id, mapping): + def update_mapping(self, request, mapping_id, mapping): mapping = self._normalize_dict(mapping) utils.validate_mapping_structure(mapping) mapping_ref = self.federation_api.update_mapping(mapping_id, mapping) - return MappingController.wrap_member(context, mapping_ref) + return MappingController.wrap_member(request.context_dict, mapping_ref) @dependency.requires('federation_api') @@ -281,7 +284,7 @@ class Auth(auth_controllers.Auth): return host - def federated_authentication(self, context, idp_id, protocol_id): + def federated_authentication(self, request, idp_id, protocol_id): """Authenticate from dedicated url endpoint. Build HTTP request body for federated authentication and inject @@ -298,34 +301,37 @@ class Auth(auth_controllers.Auth): } } - return self.authenticate_for_token(context, auth=auth) + return self.authenticate_for_token(request, auth=auth) - def federated_sso_auth(self, context, protocol_id): + def federated_sso_auth(self, request, protocol_id): try: remote_id_name = utils.get_remote_id_parameter(protocol_id) - remote_id = context['environment'][remote_id_name] + remote_id = request.context_dict['environment'][remote_id_name] except KeyError: msg = _('Missing entity ID from environment') LOG.error(msg) raise exception.Unauthorized(msg) - host = self._get_sso_origin_host(context) + host = self._get_sso_origin_host(request.context_dict) ref = self.federation_api.get_idp_from_remote_id(remote_id) # NOTE(stevemar): the returned object is a simple dict that # contains the idp_id and remote_id. identity_provider = ref['idp_id'] - res = self.federated_authentication(context, identity_provider, + res = self.federated_authentication(request, + identity_provider, protocol_id) token_id = res.headers['X-Subject-Token'] return self.render_html_response(host, token_id) - def federated_idp_specific_sso_auth(self, context, idp_id, protocol_id): - host = self._get_sso_origin_host(context) + def federated_idp_specific_sso_auth(self, request, idp_id, protocol_id): + host = self._get_sso_origin_host(request.context_dict) # NOTE(lbragstad): We validate that the Identity Provider actually # exists in the Mapped authentication plugin. - res = self.federated_authentication(context, idp_id, protocol_id) + res = self.federated_authentication(request, + idp_id, + protocol_id) token_id = res.headers['X-Subject-Token'] return self.render_html_response(host, token_id) @@ -378,13 +384,13 @@ class Auth(auth_controllers.Auth): ('X-auth-url', service_provider['auth_url'].encode('utf-8'))] @validation.validated(schema.saml_create, 'auth') - def create_saml_assertion(self, context, auth): + def create_saml_assertion(self, request, auth): """Exchange a scoped token for a SAML assertion. :param auth: Dictionary that contains a token and service provider ID :returns: SAML Assertion based on properties from the token """ - t = self._create_base_saml_assertion(context, auth) + t = self._create_base_saml_assertion(request.context_dict, auth) (response, service_provider) = t headers = self._build_response_headers(service_provider) @@ -423,17 +429,18 @@ class DomainV3(controller.V3Controller): self.get_member_from_driver = self.resource_api.get_domain @controller.protected() - def list_domains_for_groups(self, context): + def list_domains_for_groups(self, request): """List all domains available to an authenticated user's groups. :param context: request context :returns: list of accessible domains """ - auth_context = context['environment'][authorization.AUTH_CONTEXT_ENV] + env = request.context_dict['environment'] + auth_context = env[authorization.AUTH_CONTEXT_ENV] domains = self.assignment_api.list_domains_for_groups( auth_context['group_ids']) - return DomainV3.wrap_collection(context, domains) + return DomainV3.wrap_collection(request.context_dict, domains) @dependency.requires('assignment_api', 'resource_api') @@ -446,17 +453,19 @@ class ProjectAssignmentV3(controller.V3Controller): self.get_member_from_driver = self.resource_api.get_project @controller.protected() - def list_projects_for_groups(self, context): + def list_projects_for_groups(self, request): """List all projects available to an authenticated user's groups. :param context: request context :returns: list of accessible projects """ - auth_context = context['environment'][authorization.AUTH_CONTEXT_ENV] + env = request.context_dict['environment'] + auth_context = env[authorization.AUTH_CONTEXT_ENV] projects = self.assignment_api.list_projects_for_groups( auth_context['group_ids']) - return ProjectAssignmentV3.wrap_collection(context, projects) + return ProjectAssignmentV3.wrap_collection(request.context_dict, + projects) @dependency.requires('federation_api') @@ -471,37 +480,38 @@ class ServiceProvider(_ControllerBase): @controller.protected() @validation.validated(schema.service_provider_create, 'service_provider') - def create_service_provider(self, context, sp_id, service_provider): + def create_service_provider(self, request, sp_id, service_provider): service_provider = self._normalize_dict(service_provider) service_provider.setdefault('enabled', False) service_provider.setdefault('relay_state_prefix', CONF.saml.relay_state_prefix) sp_ref = self.federation_api.create_sp(sp_id, service_provider) - response = ServiceProvider.wrap_member(context, sp_ref) + response = ServiceProvider.wrap_member(request.context_dict, sp_ref) return wsgi.render_response(body=response, status=('201', 'Created')) @controller.filterprotected('id', 'enabled') - def list_service_providers(self, context, filters): - hints = self.build_driver_hints(context, filters) + def list_service_providers(self, request, filters): + hints = self.build_driver_hints(request.context_dict, filters) ref = self.federation_api.list_sps(hints=hints) ref = [self.filter_params(x) for x in ref] - return ServiceProvider.wrap_collection(context, ref, hints=hints) + return ServiceProvider.wrap_collection(request.context_dict, + ref, hints=hints) @controller.protected() - def get_service_provider(self, context, sp_id): + def get_service_provider(self, request, sp_id): ref = self.federation_api.get_sp(sp_id) - return ServiceProvider.wrap_member(context, ref) + return ServiceProvider.wrap_member(request.context_dict, ref) @controller.protected() - def delete_service_provider(self, context, sp_id): + def delete_service_provider(self, request, sp_id): self.federation_api.delete_sp(sp_id) @controller.protected() @validation.validated(schema.service_provider_update, 'service_provider') - def update_service_provider(self, context, sp_id, service_provider): + def update_service_provider(self, request, sp_id, service_provider): service_provider = self._normalize_dict(service_provider) sp_ref = self.federation_api.update_sp(sp_id, service_provider) - return ServiceProvider.wrap_member(context, sp_ref) + return ServiceProvider.wrap_member(request.context_dict, sp_ref) class SAMLMetadataV3(_ControllerBase): diff --git a/keystone/identity/controllers.py b/keystone/identity/controllers.py index 38c9cd182..30bbc286a 100644 --- a/keystone/identity/controllers.py +++ b/keystone/identity/controllers.py @@ -34,38 +34,39 @@ LOG = log.getLogger(__name__) class User(controller.V2Controller): @controller.v2_deprecated - def get_user(self, context, user_id): - self.assert_admin(context) + def get_user(self, request, user_id): + self.assert_admin(request.context_dict) ref = self.identity_api.get_user(user_id) return {'user': self.v3_to_v2_user(ref)} @controller.v2_deprecated - def get_users(self, context): + def get_users(self, request): # NOTE(termie): i can't imagine that this really wants all the data # about every single user in the system... - if 'name' in context['query_string']: + if 'name' in request.context_dict['query_string']: return self.get_user_by_name( - context, context['query_string'].get('name')) + request, + request.context_dict['query_string'].get('name')) - self.assert_admin(context) + self.assert_admin(request.context_dict) user_list = self.identity_api.list_users( CONF.identity.default_domain_id) return {'users': self.v3_to_v2_user(user_list)} @controller.v2_deprecated - def get_user_by_name(self, context, user_name): - self.assert_admin(context) + def get_user_by_name(self, request, user_name): + self.assert_admin(request.context_dict) ref = self.identity_api.get_user_by_name( user_name, CONF.identity.default_domain_id) return {'user': self.v3_to_v2_user(ref)} # CRUD extension @controller.v2_deprecated - def create_user(self, context, user): + def create_user(self, request, user): user = self._normalize_OSKSADM_password_on_request(user) user = self.normalize_username_in_request(user) user = self._normalize_dict(user) - self.assert_admin(context) + self.assert_admin(request.context_dict) if 'name' not in user or not user['name']: msg = _('Name field is required and cannot be empty') @@ -83,8 +84,8 @@ class User(controller.V2Controller): self.resource_api.ensure_default_domain_exists() # The manager layer will generate the unique ID for users - user_ref = self._normalize_domain_id(context, user.copy()) - initiator = notifications._get_request_audit_info(context) + user_ref = self._normalize_domain_id(request.context_dict, user.copy()) + initiator = notifications._get_request_audit_info(request.context_dict) new_user_ref = self.v3_to_v2_user( self.identity_api.create_user(user_ref, initiator)) @@ -94,10 +95,10 @@ class User(controller.V2Controller): return {'user': new_user_ref} @controller.v2_deprecated - def update_user(self, context, user_id, user): + def update_user(self, request, user_id, user): # NOTE(termie): this is really more of a patch than a put user = self.normalize_username_in_request(user) - self.assert_admin(context) + self.assert_admin(request.context_dict) if 'enabled' in user and not isinstance(user['enabled'], bool): msg = _('Enabled field should be a boolean') @@ -123,7 +124,7 @@ class User(controller.V2Controller): # user update. self.resource_api.get_project(default_project_id) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) user_ref = self.v3_to_v2_user( self.identity_api.update_user(user_id, user, initiator)) @@ -168,19 +169,19 @@ class User(controller.V2Controller): return {'user': user_ref} @controller.v2_deprecated - def delete_user(self, context, user_id): - self.assert_admin(context) - initiator = notifications._get_request_audit_info(context) + def delete_user(self, request, user_id): + self.assert_admin(request.context_dict) + initiator = notifications._get_request_audit_info(request.context_dict) self.identity_api.delete_user(user_id, initiator) @controller.v2_deprecated - def set_user_enabled(self, context, user_id, user): - return self.update_user(context, user_id, user) + def set_user_enabled(self, request, user_id, user): + return self.update_user(request, user_id, user) @controller.v2_deprecated - def set_user_password(self, context, user_id, user): + def set_user_password(self, request, user_id, user): user = self._normalize_OSKSADM_password_on_request(user) - return self.update_user(context, user_id, user) + return self.update_user(request, user_id, user) @staticmethod def _normalize_OSKSADM_password_on_request(ref): @@ -218,33 +219,32 @@ class UserV3(controller.V3Controller): @controller.protected() @validation.validated(schema.user_create, 'user') - def create_user(self, context, user): + def create_user(self, request, user): # The manager layer will generate the unique ID for users ref = self._normalize_dict(user) - ref = self._normalize_domain_id(context, ref) - initiator = notifications._get_request_audit_info(context) + ref = self._normalize_domain_id(request.context_dict, ref) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.identity_api.create_user(ref, initiator) - return UserV3.wrap_member(context, ref) + return UserV3.wrap_member(request.context_dict, ref) @controller.filterprotected('domain_id', 'enabled', 'name') - def list_users(self, context, filters): - hints = UserV3.build_driver_hints(context, filters) - refs = self.identity_api.list_users( - domain_scope=self._get_domain_id_for_list_request(context), - hints=hints) - return UserV3.wrap_collection(context, refs, hints=hints) + def list_users(self, request, filters): + hints = UserV3.build_driver_hints(request.context_dict, filters) + domain = self._get_domain_id_for_list_request(request.context_dict) + refs = self.identity_api.list_users(domain_scope=domain, hints=hints) + return UserV3.wrap_collection(request.context_dict, refs, hints=hints) @controller.filterprotected('domain_id', 'enabled', 'name', callback=_check_group_protection) - def list_users_in_group(self, context, filters, group_id): - hints = UserV3.build_driver_hints(context, filters) + def list_users_in_group(self, request, filters, group_id): + hints = UserV3.build_driver_hints(request.context_dict, filters) refs = self.identity_api.list_users_in_group(group_id, hints=hints) - return UserV3.wrap_collection(context, refs, hints=hints) + return UserV3.wrap_collection(request.context_dict, refs, hints=hints) @controller.protected() - def get_user(self, context, user_id): + def get_user(self, request, user_id): ref = self.identity_api.get_user(user_id) - return UserV3.wrap_member(context, ref) + return UserV3.wrap_member(request.context_dict, ref) def _update_user(self, context, user_id, user): self._require_matching_id(user_id, user) @@ -256,30 +256,30 @@ class UserV3(controller.V3Controller): @controller.protected() @validation.validated(schema.user_update, 'user') - def update_user(self, context, user_id, user): - return self._update_user(context, user_id, user) + def update_user(self, request, user_id, user): + return self._update_user(request.context_dict, user_id, user) @controller.protected(callback=_check_user_and_group_protection) - def add_user_to_group(self, context, user_id, group_id): - initiator = notifications._get_request_audit_info(context) + def add_user_to_group(self, request, user_id, group_id): + initiator = notifications._get_request_audit_info(request.context_dict) self.identity_api.add_user_to_group(user_id, group_id, initiator) @controller.protected(callback=_check_user_and_group_protection) - def check_user_in_group(self, context, user_id, group_id): + def check_user_in_group(self, request, user_id, group_id): return self.identity_api.check_user_in_group(user_id, group_id) @controller.protected(callback=_check_user_and_group_protection) - def remove_user_from_group(self, context, user_id, group_id): - initiator = notifications._get_request_audit_info(context) + def remove_user_from_group(self, request, user_id, group_id): + initiator = notifications._get_request_audit_info(request.context_dict) self.identity_api.remove_user_from_group(user_id, group_id, initiator) @controller.protected() - def delete_user(self, context, user_id): - initiator = notifications._get_request_audit_info(context) + def delete_user(self, request, user_id): + initiator = notifications._get_request_audit_info(request.context_dict) return self.identity_api.delete_user(user_id, initiator) @controller.protected() - def change_password(self, context, user_id, user): + def change_password(self, request, user_id, user): original_password = user.get('original_password') if original_password is None: raise exception.ValidationError(target='user', @@ -291,7 +291,7 @@ class UserV3(controller.V3Controller): attribute='password') try: self.identity_api.change_password( - context, user_id, original_password, password) + request.context_dict, user_id, original_password, password) except AssertionError: raise exception.Unauthorized() @@ -312,44 +312,43 @@ class GroupV3(controller.V3Controller): @controller.protected() @validation.validated(schema.group_create, 'group') - def create_group(self, context, group): + def create_group(self, request, group): # The manager layer will generate the unique ID for groups ref = self._normalize_dict(group) - ref = self._normalize_domain_id(context, ref) - initiator = notifications._get_request_audit_info(context) + ref = self._normalize_domain_id(request.context_dict, ref) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.identity_api.create_group(ref, initiator) - return GroupV3.wrap_member(context, ref) + return GroupV3.wrap_member(request.context_dict, ref) @controller.filterprotected('domain_id', 'name') - def list_groups(self, context, filters): - hints = GroupV3.build_driver_hints(context, filters) - refs = self.identity_api.list_groups( - domain_scope=self._get_domain_id_for_list_request(context), - hints=hints) - return GroupV3.wrap_collection(context, refs, hints=hints) + def list_groups(self, request, filters): + hints = GroupV3.build_driver_hints(request.context_dict, filters) + domain = self._get_domain_id_for_list_request(request.context_dict) + refs = self.identity_api.list_groups(domain_scope=domain, hints=hints) + return GroupV3.wrap_collection(request.context_dict, refs, hints=hints) @controller.filterprotected('name', callback=_check_user_protection) - def list_groups_for_user(self, context, filters, user_id): - hints = GroupV3.build_driver_hints(context, filters) + def list_groups_for_user(self, request, filters, user_id): + hints = GroupV3.build_driver_hints(request.context_dict, filters) refs = self.identity_api.list_groups_for_user(user_id, hints=hints) - return GroupV3.wrap_collection(context, refs, hints=hints) + return GroupV3.wrap_collection(request.context_dict, refs, hints=hints) @controller.protected() - def get_group(self, context, group_id): + def get_group(self, request, group_id): ref = self.identity_api.get_group(group_id) - return GroupV3.wrap_member(context, ref) + return GroupV3.wrap_member(request.context_dict, ref) @controller.protected() @validation.validated(schema.group_update, 'group') - def update_group(self, context, group_id, group): + def update_group(self, request, group_id, group): self._require_matching_id(group_id, group) self._require_matching_domain_id( group_id, group, self.identity_api.get_group) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.identity_api.update_group(group_id, group, initiator) - return GroupV3.wrap_member(context, ref) + return GroupV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_group(self, context, group_id): - initiator = notifications._get_request_audit_info(context) + def delete_group(self, request, group_id): + initiator = notifications._get_request_audit_info(request.context_dict) self.identity_api.delete_group(group_id, initiator) diff --git a/keystone/oauth1/controllers.py b/keystone/oauth1/controllers.py index 489bb4c7c..250ef2fa7 100644 --- a/keystone/oauth1/controllers.py +++ b/keystone/oauth1/controllers.py @@ -60,38 +60,38 @@ class ConsumerCrudV3(controller.V3Controller): @controller.protected() @validation.validated(schema.consumer_create, 'consumer') - def create_consumer(self, context, consumer): + def create_consumer(self, request, consumer): ref = self._assign_unique_id(self._normalize_dict(consumer)) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) consumer_ref = self.oauth_api.create_consumer(ref, initiator) - return ConsumerCrudV3.wrap_member(context, consumer_ref) + return ConsumerCrudV3.wrap_member(request.context_dict, consumer_ref) @controller.protected() @validation.validated(schema.consumer_update, 'consumer') - def update_consumer(self, context, consumer_id, consumer): + def update_consumer(self, request, consumer_id, consumer): self._require_matching_id(consumer_id, consumer) ref = self._normalize_dict(consumer) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.oauth_api.update_consumer(consumer_id, ref, initiator) - return ConsumerCrudV3.wrap_member(context, ref) + return ConsumerCrudV3.wrap_member(request.context_dict, ref) @controller.protected() - def list_consumers(self, context): + def list_consumers(self, request): ref = self.oauth_api.list_consumers() - return ConsumerCrudV3.wrap_collection(context, ref) + return ConsumerCrudV3.wrap_collection(request.context_dict, ref) @controller.protected() - def get_consumer(self, context, consumer_id): + def get_consumer(self, request, consumer_id): ref = self.oauth_api.get_consumer(consumer_id) - return ConsumerCrudV3.wrap_member(context, ref) + return ConsumerCrudV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_consumer(self, context, consumer_id): - user_token_ref = utils.get_token_ref(context) + def delete_consumer(self, request, consumer_id): + user_token_ref = utils.get_token_ref(request.context_dict) payload = {'user_id': user_token_ref.user_id, 'consumer_id': consumer_id} _emit_user_oauth_consumer_token_invalidate(payload) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) self.oauth_api.delete_consumer(consumer_id, initiator) @@ -110,33 +110,36 @@ class AccessTokenCrudV3(controller.V3Controller): ref['links']['self'] = cls.base_url(context, path) + '/' + ref['id'] @controller.protected() - def get_access_token(self, context, user_id, access_token_id): + def get_access_token(self, request, user_id, access_token_id): access_token = self.oauth_api.get_access_token(access_token_id) if access_token['authorizing_user_id'] != user_id: raise exception.NotFound() - access_token = self._format_token_entity(context, access_token) - return AccessTokenCrudV3.wrap_member(context, access_token) + access_token = self._format_token_entity(request.context_dict, + access_token) + return AccessTokenCrudV3.wrap_member(request.context_dict, + access_token) @controller.protected() - def list_access_tokens(self, context, user_id): - auth_context = context.get('environment', - {}).get('KEYSTONE_AUTH_CONTEXT', {}) + def list_access_tokens(self, request, user_id): + env = request.context_dict.get('environment', {}) + auth_context = env.get('KEYSTONE_AUTH_CONTEXT', {}) if auth_context.get('is_delegated_auth'): raise exception.Forbidden( _('Cannot list request tokens' ' with a token issued via delegation.')) refs = self.oauth_api.list_access_tokens(user_id) - formatted_refs = ([self._format_token_entity(context, x) + formatted_refs = ([self._format_token_entity(request.context_dict, x) for x in refs]) - return AccessTokenCrudV3.wrap_collection(context, formatted_refs) + return AccessTokenCrudV3.wrap_collection(request.context_dict, + formatted_refs) @controller.protected() - def delete_access_token(self, context, user_id, access_token_id): + def delete_access_token(self, request, user_id, access_token_id): access_token = self.oauth_api.get_access_token(access_token_id) consumer_id = access_token['consumer_id'] payload = {'user_id': user_id, 'consumer_id': consumer_id} _emit_user_oauth_consumer_token_invalidate(payload) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) return self.oauth_api.delete_access_token( user_id, access_token_id, initiator) @@ -170,17 +173,17 @@ class AccessTokenRolesV3(controller.V3Controller): member_name = 'role' @controller.protected() - def list_access_token_roles(self, context, user_id, access_token_id): + def list_access_token_roles(self, request, user_id, access_token_id): access_token = self.oauth_api.get_access_token(access_token_id) if access_token['authorizing_user_id'] != user_id: raise exception.NotFound() authed_role_ids = access_token['role_ids'] authed_role_ids = jsonutils.loads(authed_role_ids) refs = ([self._format_role_entity(x) for x in authed_role_ids]) - return AccessTokenRolesV3.wrap_collection(context, refs) + return AccessTokenRolesV3.wrap_collection(request.context_dict, refs) @controller.protected() - def get_access_token_role(self, context, user_id, + def get_access_token_role(self, request, user_id, access_token_id, role_id): access_token = self.oauth_api.get_access_token(access_token_id) if access_token['authorizing_user_id'] != user_id: @@ -190,7 +193,8 @@ class AccessTokenRolesV3(controller.V3Controller): for authed_role_id in authed_role_ids: if authed_role_id == role_id: role = self._format_role_entity(role_id) - return AccessTokenRolesV3.wrap_member(context, role) + return AccessTokenRolesV3.wrap_member(request.context_dict, + role) raise exception.RoleNotFound(role_id=role_id) def _format_role_entity(self, role_id): @@ -209,8 +213,8 @@ class OAuthControllerV3(controller.V3Controller): collection_name = 'not_used' member_name = 'not_used' - def create_request_token(self, context): - headers = context['headers'] + def create_request_token(self, request): + headers = request.context_dict['headers'] oauth_headers = oauth1.get_oauth_headers(headers) consumer_id = oauth_headers.get('oauth_consumer_key') requested_project_id = headers.get('Requested-Project-Id') @@ -226,7 +230,7 @@ class OAuthControllerV3(controller.V3Controller): self.resource_api.get_project(requested_project_id) self.oauth_api.get_consumer(consumer_id) - url = self.base_url(context, context['path']) + url = self.base_url(request.context_dict, request.context_dict['path']) req_headers = {'Requested-Project-Id': requested_project_id} req_headers.update(headers) @@ -236,7 +240,7 @@ class OAuthControllerV3(controller.V3Controller): h, b, s = request_verifier.create_request_token_response( url, http_method='POST', - body=context['query_string'], + body=request.context_dict['query_string'], headers=req_headers) if (not b) or int(s) > 399: @@ -244,7 +248,7 @@ class OAuthControllerV3(controller.V3Controller): raise exception.Unauthorized(message=msg) request_token_duration = CONF.oauth1.request_token_duration - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) token_ref = self.oauth_api.create_request_token(consumer_id, requested_project_id, request_token_duration, @@ -265,8 +269,8 @@ class OAuthControllerV3(controller.V3Controller): return response - def create_access_token(self, context): - headers = context['headers'] + def create_access_token(self, request): + headers = request.context_dict['headers'] oauth_headers = oauth1.get_oauth_headers(headers) consumer_id = oauth_headers.get('oauth_consumer_key') request_token_id = oauth_headers.get('oauth_token') @@ -293,7 +297,7 @@ class OAuthControllerV3(controller.V3Controller): if now > expires: raise exception.Unauthorized(_('Request token is expired')) - url = self.base_url(context, context['path']) + url = self.base_url(request.context_dict, request.context_dict['path']) access_verifier = oauth1.AccessTokenEndpoint( request_validator=validator.OAuthValidator(), @@ -301,7 +305,7 @@ class OAuthControllerV3(controller.V3Controller): h, b, s = access_verifier.create_access_token_response( url, http_method='POST', - body=context['query_string'], + body=request.context_dict['query_string'], headers=headers) params = oauth1.extract_non_oauth_params(b) if params: @@ -325,7 +329,7 @@ class OAuthControllerV3(controller.V3Controller): raise exception.Unauthorized(message=msg) access_token_duration = CONF.oauth1.access_token_duration - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) token_ref = self.oauth_api.create_access_token(request_token_id, access_token_duration, initiator) @@ -346,7 +350,7 @@ class OAuthControllerV3(controller.V3Controller): return response @controller.protected() - def authorize_request_token(self, context, request_token_id, roles): + def authorize_request_token(self, request, request_token_id, roles): """An authenticated user is going to authorize a request token. As a security precaution, the requested roles must match those in @@ -354,8 +358,8 @@ class OAuthControllerV3(controller.V3Controller): there is not another easy way to make sure the user knows which roles are being requested before authorizing. """ - auth_context = context.get('environment', - {}).get('KEYSTONE_AUTH_CONTEXT', {}) + env = request.context_dict.get('environment', {}) + auth_context = env.get('KEYSTONE_AUTH_CONTEXT', {}) if auth_context.get('is_delegated_auth'): raise exception.Forbidden( _('Cannot authorize a request token' @@ -377,7 +381,7 @@ class OAuthControllerV3(controller.V3Controller): authed_roles.add(role['id']) # verify the authorizing user has the roles - user_token = utils.get_token_ref(context) + user_token = utils.get_token_ref(request.context_dict) user_id = user_token.user_id project_id = req_token['requested_project_id'] user_roles = self.assignment_api.get_roles_for_user_and_project( diff --git a/keystone/policy/controllers.py b/keystone/policy/controllers.py index e6eb9bca8..67d229c0f 100644 --- a/keystone/policy/controllers.py +++ b/keystone/policy/controllers.py @@ -26,31 +26,32 @@ class PolicyV3(controller.V3Controller): @controller.protected() @validation.validated(schema.policy_create, 'policy') - def create_policy(self, context, policy): + def create_policy(self, request, policy): ref = self._assign_unique_id(self._normalize_dict(policy)) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.policy_api.create_policy(ref['id'], ref, initiator) - return PolicyV3.wrap_member(context, ref) + return PolicyV3.wrap_member(request.context_dict, ref) @controller.filterprotected('type') - def list_policies(self, context, filters): - hints = PolicyV3.build_driver_hints(context, filters) + def list_policies(self, request, filters): + hints = PolicyV3.build_driver_hints(request.context_dict, filters) refs = self.policy_api.list_policies(hints=hints) - return PolicyV3.wrap_collection(context, refs, hints=hints) + return PolicyV3.wrap_collection(request.context_dict, + refs, hints=hints) @controller.protected() - def get_policy(self, context, policy_id): + def get_policy(self, request, policy_id): ref = self.policy_api.get_policy(policy_id) - return PolicyV3.wrap_member(context, ref) + return PolicyV3.wrap_member(request.context_dict, ref) @controller.protected() @validation.validated(schema.policy_update, 'policy') - def update_policy(self, context, policy_id, policy): - initiator = notifications._get_request_audit_info(context) + def update_policy(self, request, policy_id, policy): + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.policy_api.update_policy(policy_id, policy, initiator) - return PolicyV3.wrap_member(context, ref) + return PolicyV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_policy(self, context, policy_id): - initiator = notifications._get_request_audit_info(context) + def delete_policy(self, request, policy_id): + initiator = notifications._get_request_audit_info(request.context_dict) return self.policy_api.delete_policy(policy_id, initiator) diff --git a/keystone/resource/controllers.py b/keystone/resource/controllers.py index e8dabc657..2f0613b05 100644 --- a/keystone/resource/controllers.py +++ b/keystone/resource/controllers.py @@ -36,12 +36,13 @@ CONF = cfg.CONF class Tenant(controller.V2Controller): @controller.v2_deprecated - def get_all_projects(self, context, **kw): + def get_all_projects(self, request, **kw): """Get a list of all tenants for an admin user.""" - self.assert_admin(context) + self.assert_admin(request.context_dict) - if 'name' in context['query_string']: - return self._get_project_by_name(context['query_string']['name']) + name = request.context_dict['query_string'].get('name') + if name: + return self._get_project_by_name(name) try: tenant_refs = self.resource_api.list_projects_in_domain( @@ -54,8 +55,8 @@ class Tenant(controller.V2Controller): for tenant_ref in tenant_refs if not tenant_ref.get('is_domain')] params = { - 'limit': context['query_string'].get('limit'), - 'marker': context['query_string'].get('marker'), + 'limit': request.context_dict['query_string'].get('limit'), + 'marker': request.context_dict['query_string'].get('marker'), } return self.format_project_list(tenant_refs, **params) @@ -67,9 +68,9 @@ class Tenant(controller.V2Controller): raise exception.ProjectNotFound(project_id) @controller.v2_deprecated - def get_project(self, context, tenant_id): + def get_project(self, request, tenant_id): # TODO(termie): this stuff should probably be moved to middleware - self.assert_admin(context) + self.assert_admin(request.context_dict) ref = self.resource_api.get_project(tenant_id) self._assert_not_is_domain_project(tenant_id, ref) return {'tenant': self.v3_to_v2_project(ref)} @@ -83,7 +84,7 @@ class Tenant(controller.V2Controller): # CRUD Extension @controller.v2_deprecated - def create_project(self, context, tenant): + def create_project(self, request, tenant): tenant_ref = self._normalize_dict(tenant) if 'name' not in tenant_ref or not tenant_ref['name']: @@ -95,37 +96,37 @@ class Tenant(controller.V2Controller): 'allowed in v2.') raise exception.ValidationError(message=msg) - self.assert_admin(context) + self.assert_admin(request.context_dict) self.resource_api.ensure_default_domain_exists() tenant_ref['id'] = tenant_ref.get('id', uuid.uuid4().hex) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) tenant = self.resource_api.create_project( tenant_ref['id'], - self._normalize_domain_id(context, tenant_ref), + self._normalize_domain_id(request.context_dict, tenant_ref), initiator) return {'tenant': self.v3_to_v2_project(tenant)} @controller.v2_deprecated - def update_project(self, context, tenant_id, tenant): - self.assert_admin(context) + def update_project(self, request, tenant_id, tenant): + self.assert_admin(request.context_dict) self._assert_not_is_domain_project(tenant_id) # Remove domain_id and is_domain if specified - a v2 api caller # should not be specifying that clean_tenant = tenant.copy() clean_tenant.pop('domain_id', None) clean_tenant.pop('is_domain', None) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) tenant_ref = self.resource_api.update_project( tenant_id, clean_tenant, initiator) return {'tenant': self.v3_to_v2_project(tenant_ref)} @controller.v2_deprecated - def delete_project(self, context, tenant_id): - self.assert_admin(context) + def delete_project(self, request, tenant_id): + self.assert_admin(request.context_dict) self._assert_not_is_domain_project(tenant_id) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) self.resource_api.delete_project(tenant_id, initiator) @@ -140,34 +141,35 @@ class DomainV3(controller.V3Controller): @controller.protected() @validation.validated(schema.domain_create, 'domain') - def create_domain(self, context, domain): + def create_domain(self, request, domain): ref = self._assign_unique_id(self._normalize_dict(domain)) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.resource_api.create_domain(ref['id'], ref, initiator) - return DomainV3.wrap_member(context, ref) + return DomainV3.wrap_member(request.context_dict, ref) @controller.filterprotected('enabled', 'name') - def list_domains(self, context, filters): - hints = DomainV3.build_driver_hints(context, filters) + def list_domains(self, request, filters): + hints = DomainV3.build_driver_hints(request.context_dict, filters) refs = self.resource_api.list_domains(hints=hints) - return DomainV3.wrap_collection(context, refs, hints=hints) + return DomainV3.wrap_collection(request.context_dict, + refs, hints=hints) @controller.protected() - def get_domain(self, context, domain_id): + def get_domain(self, request, domain_id): ref = self.resource_api.get_domain(domain_id) - return DomainV3.wrap_member(context, ref) + return DomainV3.wrap_member(request.context_dict, ref) @controller.protected() @validation.validated(schema.domain_update, 'domain') - def update_domain(self, context, domain_id, domain): + def update_domain(self, request, domain_id, domain): self._require_matching_id(domain_id, domain) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.resource_api.update_domain(domain_id, domain, initiator) - return DomainV3.wrap_member(context, ref) + return DomainV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_domain(self, context, domain_id): - initiator = notifications._get_request_audit_info(context) + def delete_domain(self, request, domain_id): + initiator = notifications._get_request_audit_info(request.context_dict) return self.resource_api.delete_domain(domain_id, initiator) @@ -177,7 +179,7 @@ class DomainConfigV3(controller.V3Controller): member_name = 'config' @controller.protected() - def create_domain_config(self, context, domain_id, config): + def create_domain_config(self, request, domain_id, config): self.resource_api.get_domain(domain_id) original_config = ( self.domain_config_api.get_config_with_sensitive_info(domain_id)) @@ -190,14 +192,14 @@ class DomainConfigV3(controller.V3Controller): status=('201', 'Created')) @controller.protected() - def get_domain_config(self, context, domain_id, group=None, option=None): + def get_domain_config(self, request, domain_id, group=None, option=None): self.resource_api.get_domain(domain_id) ref = self.domain_config_api.get_config(domain_id, group, option) return {self.member_name: ref} @controller.protected() def update_domain_config( - self, context, domain_id, config, group, option): + self, request, domain_id, config, group, option): self.resource_api.get_domain(domain_id) ref = self.domain_config_api.update_config( domain_id, config, group, option) @@ -215,12 +217,12 @@ class DomainConfigV3(controller.V3Controller): @controller.protected() def delete_domain_config( - self, context, domain_id, group=None, option=None): + self, request, domain_id, group=None, option=None): self.resource_api.get_domain(domain_id) self.domain_config_api.delete_config(domain_id, group, option) @controller.protected() - def get_domain_config_default(self, context, group=None, option=None): + def get_domain_config_default(self, request, group=None, option=None): ref = self.domain_config_api.get_config_default(group, option) return {self.member_name: ref} @@ -236,35 +238,36 @@ class ProjectV3(controller.V3Controller): @controller.protected() @validation.validated(schema.project_create, 'project') - def create_project(self, context, project): + def create_project(self, request, project): ref = self._assign_unique_id(self._normalize_dict(project)) if not ref.get('is_domain'): - ref = self._normalize_domain_id(context, ref) + ref = self._normalize_domain_id(request.context_dict, ref) # Our API requires that you specify the location in the hierarchy # unambiguously. This could be by parent_id or, if it is a top level # project, just by providing a domain_id. if not ref.get('parent_id'): ref['parent_id'] = ref.get('domain_id') - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) try: ref = self.resource_api.create_project(ref['id'], ref, initiator=initiator) except (exception.DomainNotFound, exception.ProjectNotFound) as e: raise exception.ValidationError(e) - return ProjectV3.wrap_member(context, ref) + return ProjectV3.wrap_member(request.context_dict, ref) @controller.filterprotected('domain_id', 'enabled', 'name', 'parent_id', 'is_domain') - def list_projects(self, context, filters): - hints = ProjectV3.build_driver_hints(context, filters) + def list_projects(self, request, filters): + hints = ProjectV3.build_driver_hints(request.context_dict, filters) # If 'is_domain' has not been included as a query, we default it to # False (which in query terms means '0' - if 'is_domain' not in context['query_string']: + if 'is_domain' not in request.context_dict['query_string']: hints.add_filter('is_domain', '0') refs = self.resource_api.list_projects(hints=hints) - return ProjectV3.wrap_collection(context, refs, hints=hints) + return ProjectV3.wrap_collection(request.context_dict, + refs, hints=hints) def _expand_project_ref(self, context, ref): params = context['query_string'] @@ -311,24 +314,24 @@ class ProjectV3(controller.V3Controller): ref['id']) @controller.protected() - def get_project(self, context, project_id): + def get_project(self, request, project_id): ref = self.resource_api.get_project(project_id) - self._expand_project_ref(context, ref) - return ProjectV3.wrap_member(context, ref) + self._expand_project_ref(request.context_dict, ref) + return ProjectV3.wrap_member(request.context_dict, ref) @controller.protected() @validation.validated(schema.project_update, 'project') - def update_project(self, context, project_id, project): + def update_project(self, request, project_id, project): self._require_matching_id(project_id, project) self._require_matching_domain_id( project_id, project, self.resource_api.get_project) - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) ref = self.resource_api.update_project(project_id, project, initiator=initiator) - return ProjectV3.wrap_member(context, ref) + return ProjectV3.wrap_member(request.context_dict, ref) @controller.protected() - def delete_project(self, context, project_id): - initiator = notifications._get_request_audit_info(context) + def delete_project(self, request, project_id): + initiator = notifications._get_request_audit_info(request.context_dict) return self.resource_api.delete_project(project_id, initiator=initiator) diff --git a/keystone/revoke/controllers.py b/keystone/revoke/controllers.py index 40151baea..d2ae7b93c 100644 --- a/keystone/revoke/controllers.py +++ b/keystone/revoke/controllers.py @@ -21,8 +21,8 @@ from keystone.i18n import _ @dependency.requires('revoke_api') class RevokeController(controller.V3Controller): @controller.protected() - def list_revoke_events(self, context): - since = context['query_string'].get('since') + def list_revoke_events(self, request): + since = request.context_dict['query_string'].get('since') last_fetch = None if since: try: @@ -37,8 +37,8 @@ class RevokeController(controller.V3Controller): 'links': { 'next': None, 'self': RevokeController.base_url( - context, - path=context['path']), + request.context_dict, + path=request.context_dict['path']), 'previous': None} } return response diff --git a/keystone/tests/unit/core.py b/keystone/tests/unit/core.py index 347dfb06c..b24ba508e 100644 --- a/keystone/tests/unit/core.py +++ b/keystone/tests/unit/core.py @@ -46,6 +46,7 @@ from keystone import auth from keystone.common import config from keystone.common import dependency from keystone.common.kvs import core as kvs_core +from keystone.common import request from keystone.common import sql from keystone import exception from keystone.identity.backends.ldap import common as ks_ldap @@ -574,6 +575,20 @@ class TestCase(BaseTestCase): def _policy_fixture(self): return ksfixtures.Policy(dirs.etc('policy.json'), self.config_fixture) + def make_request(self, path='/', **kwargs): + context = {} + + for k in ('is_admin', 'query_string'): + try: + context[k] = kwargs.pop(k) + except KeyError: + pass + + req = request.Request.blank(path=path, **kwargs) + req.context_dict.update(context) + + return req + def config_overrides(self): # NOTE(morganfainberg): enforce config_overrides can only ever be # called a single time. diff --git a/keystone/tests/unit/identity/test_controllers.py b/keystone/tests/unit/identity/test_controllers.py index ed2fe3ffb..e173b18b3 100644 --- a/keystone/tests/unit/identity/test_controllers.py +++ b/keystone/tests/unit/identity/test_controllers.py @@ -24,8 +24,6 @@ from keystone.tests.unit.ksfixtures import database CONF = cfg.CONF -_ADMIN_CONTEXT = {'is_admin': True, 'query_string': {}} - class UserTestCaseNoDefaultDomain(unit.TestCase): @@ -45,7 +43,7 @@ class UserTestCaseNoDefaultDomain(unit.TestCase): def test_get_users(self): # When list_users is done and there's no default domain, the result is # an empty list. - res = self.user_controller.get_users(_ADMIN_CONTEXT) + res = self.user_controller.get_users(self.make_request(is_admin=True)) self.assertEqual([], res['users']) def test_get_user_by_name(self): @@ -54,12 +52,14 @@ class UserTestCaseNoDefaultDomain(unit.TestCase): user_name = uuid.uuid4().hex self.assertRaises( exception.UserNotFound, - self.user_controller.get_user_by_name, _ADMIN_CONTEXT, user_name) + self.user_controller.get_user_by_name, + self.make_request(is_admin=True), user_name) def test_create_user(self): # When a user is created using the v2 controller and there's no default # domain, it doesn't fail with can't find domain (a default domain is # created) user = {'name': uuid.uuid4().hex} - self.user_controller.create_user(_ADMIN_CONTEXT, user) + self.user_controller.create_user(self.make_request(is_admin=True), + user) # If the above doesn't fail then this is successful. diff --git a/keystone/tests/unit/resource/test_controllers.py b/keystone/tests/unit/resource/test_controllers.py index b8f247c85..52ac6eba5 100644 --- a/keystone/tests/unit/resource/test_controllers.py +++ b/keystone/tests/unit/resource/test_controllers.py @@ -24,8 +24,6 @@ from keystone.tests.unit.ksfixtures import database CONF = cfg.CONF -_ADMIN_CONTEXT = {'is_admin': True, 'query_string': {}} - class TenantTestCaseNoDefaultDomain(unit.TestCase): @@ -45,7 +43,8 @@ class TenantTestCaseNoDefaultDomain(unit.TestCase): def test_get_all_projects(self): # When get_all_projects is done and there's no default domain, the # result is an empty list. - res = self.tenant_controller.get_all_projects(_ADMIN_CONTEXT) + req = self.make_request(is_admin=True) + res = self.tenant_controller.get_all_projects(req) self.assertEqual([], res['tenants']) def test_create_project(self): @@ -53,5 +52,6 @@ class TenantTestCaseNoDefaultDomain(unit.TestCase): # default domain, it doesn't fail with can't find domain (a default # domain is created) tenant = {'name': uuid.uuid4().hex} - self.tenant_controller.create_project(_ADMIN_CONTEXT, tenant) + self.tenant_controller.create_project(self.make_request(is_admin=True), + tenant) # If the above doesn't fail then this is successful. diff --git a/keystone/tests/unit/test_auth.py b/keystone/tests/unit/test_auth.py index a26ab1ead..8970cd993 100644 --- a/keystone/tests/unit/test_auth.py +++ b/keystone/tests/unit/test_auth.py @@ -84,10 +84,10 @@ class AuthTest(unit.TestCase): self.load_backends() self.load_fixtures(default_fixtures) - self.context_with_remote_user = {'environment': - {'REMOTE_USER': 'FOO', - 'AUTH_TYPE': 'Negotiate'}} - self.empty_context = {'environment': {}} + environ = {'REMOTE_USER': 'FOO', 'AUTH_TYPE': 'Negotiate'} + self.request_with_remote_user = self.make_request(environ=environ) + + self.empty_request = self.make_request() self.controller = token.controllers.Auth() @@ -163,20 +163,20 @@ class AuthBadRequests(AuthTest): """Verify sending empty json dict raises the right exception.""" self.assertRaises(exception.ValidationError, self.controller.authenticate, - {}, {}) + self.make_request(), {}) def test_authenticate_blank_auth(self): """Verify sending blank 'auth' raises the right exception.""" body_dict = _build_user_auth() self.assertRaises(exception.ValidationError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_invalid_auth_content(self): """Verify sending invalid 'auth' raises the right exception.""" self.assertRaises(exception.ValidationError, self.controller.authenticate, - {}, {'auth': 'abcd'}) + self.make_request(), {'auth': 'abcd'}) def test_authenticate_user_id_too_large(self): """Verify sending large 'userId' raises the right exception.""" @@ -184,14 +184,14 @@ class AuthBadRequests(AuthTest): password='foo2') self.assertRaises(exception.ValidationSizeError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_username_too_large(self): """Verify sending large 'username' raises the right exception.""" body_dict = _build_user_auth(username='0' * 65, password='foo2') self.assertRaises(exception.ValidationSizeError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_tenant_id_too_large(self): """Verify sending large 'tenantId' raises the right exception.""" @@ -199,7 +199,7 @@ class AuthBadRequests(AuthTest): tenant_id='0' * 65) self.assertRaises(exception.ValidationSizeError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_tenant_name_too_large(self): """Verify sending large 'tenantName' raises the right exception.""" @@ -207,14 +207,14 @@ class AuthBadRequests(AuthTest): tenant_name='0' * 65) self.assertRaises(exception.ValidationSizeError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_token_too_large(self): """Verify sending large 'token' raises the right exception.""" body_dict = _build_user_auth(token={'id': '0' * 8193}) self.assertRaises(exception.ValidationSizeError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_password_too_large(self): """Verify sending large 'password' raises the right exception.""" @@ -222,7 +222,7 @@ class AuthBadRequests(AuthTest): body_dict = _build_user_auth(username='FOO', password='0' * length) self.assertRaises(exception.ValidationSizeError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_fails_if_project_unsafe(self): """Verify authenticate to a project with unsafe name fails.""" @@ -236,7 +236,7 @@ class AuthBadRequests(AuthTest): self.resource_api.create_project(project['id'], project) self.assignment_api.add_role_to_user_and_project( self.user_foo['id'], project['id'], self.role_member['id']) - no_context = {} + empty_request = self.make_request() body_dict = _build_user_auth( username=self.user_foo['name'], @@ -244,7 +244,7 @@ class AuthBadRequests(AuthTest): tenant_name=project['name']) # Since name url restriction is off, we should be able to authenticate - self.controller.authenticate(no_context, body_dict) + self.controller.authenticate(empty_request, body_dict) # Set the name url restriction to strict and we should fail to # authenticate @@ -252,7 +252,7 @@ class AuthBadRequests(AuthTest): project_name_url_safe='strict') self.assertRaises(exception.Unauthorized, self.controller.authenticate, - no_context, body_dict) + empty_request, body_dict) class AuthWithToken(AuthTest): @@ -260,7 +260,8 @@ class AuthWithToken(AuthTest): """Verify getting an unscoped token with password creds.""" body_dict = _build_user_auth(username='FOO', password='foo2') - unscoped_token = self.controller.authenticate({}, body_dict) + unscoped_token = self.controller.authenticate(self.make_request(), + body_dict) self.assertNotIn('tenant', unscoped_token['access']['token']) def test_auth_invalid_token(self): @@ -269,7 +270,7 @@ class AuthWithToken(AuthTest): self.assertRaises( exception.Unauthorized, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_auth_bad_formatted_token(self): """Verify exception is raised if invalid token.""" @@ -277,18 +278,20 @@ class AuthWithToken(AuthTest): self.assertRaises( exception.ValidationError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_auth_unscoped_token_no_project(self): """Verify getting an unscoped token with an unscoped token.""" body_dict = _build_user_auth( username='FOO', password='foo2') - unscoped_token = self.controller.authenticate({}, body_dict) + unscoped_token = self.controller.authenticate(self.make_request(), + body_dict) body_dict = _build_user_auth( token=unscoped_token["access"]["token"]) - unscoped_token_2 = self.controller.authenticate({}, body_dict) + unscoped_token_2 = self.controller.authenticate(self.make_request(), + body_dict) self.assertEqualTokens(unscoped_token, unscoped_token_2) @@ -303,12 +306,14 @@ class AuthWithToken(AuthTest): body_dict = _build_user_auth( username='FOO', password='foo2') - unscoped_token = self.controller.authenticate({}, body_dict) + unscoped_token = self.controller.authenticate(self.make_request(), + body_dict) # Get a token on BAR tenant using the unscoped token body_dict = _build_user_auth( token=unscoped_token["access"]["token"], tenant_name="BAR") - scoped_token = self.controller.authenticate({}, body_dict) + scoped_token = self.controller.authenticate(self.make_request(), + body_dict) tenant = scoped_token["access"]["token"]["tenant"] roles = scoped_token["access"]["metadata"]["roles"] @@ -330,7 +335,7 @@ class AuthWithToken(AuthTest): e = self.assertRaises( exception.Unauthorized, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) # explicitly verify that the error message shows that a *name* is # found where an *ID* is expected self.assertIn( @@ -352,7 +357,7 @@ class AuthWithToken(AuthTest): e = self.assertRaises( exception.Unauthorized, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) # explicitly verify that the error message details above have been # suppressed. self.assertNotIn( @@ -384,7 +389,8 @@ class AuthWithToken(AuthTest): password='foo2', tenant_name="BAR") - scoped_token = self.controller.authenticate({}, body_dict) + scoped_token = self.controller.authenticate(self.make_request(), + body_dict) tenant = scoped_token["access"]["token"]["tenant"] roles = scoped_token["access"]["metadata"]["roles"] @@ -394,7 +400,7 @@ class AuthWithToken(AuthTest): def test_belongs_to_no_tenant(self): r = self.controller.authenticate( - {}, + self.make_request(), auth={ 'passwordCredentials': { 'username': self.user_foo['name'], @@ -405,7 +411,8 @@ class AuthWithToken(AuthTest): self.assertRaises( exception.Unauthorized, self.controller.validate_token, - dict(is_admin=True, query_string={'belongsTo': 'BAR'}), + self.make_request(is_admin=True, + query_string={'belongsTo': 'BAR'}), token_id=unscoped_token_id) def test_belongs_to(self): @@ -414,26 +421,28 @@ class AuthWithToken(AuthTest): password='foo2', tenant_name="BAR") - scoped_token = self.controller.authenticate({}, body_dict) + scoped_token = self.controller.authenticate(self.make_request(), + body_dict) scoped_token_id = scoped_token['access']['token']['id'] self.assertRaises( exception.Unauthorized, self.controller.validate_token, - dict(is_admin=True, query_string={'belongsTo': 'me'}), + self.make_request(is_admin=True, query_string={'belongsTo': 'me'}), token_id=scoped_token_id) self.assertRaises( exception.Unauthorized, self.controller.validate_token, - dict(is_admin=True, query_string={'belongsTo': 'BAR'}), + self.make_request(is_admin=True, + query_string={'belongsTo': 'BAR'}), token_id=scoped_token_id) def test_token_auth_with_binding(self): self.config_fixture.config(group='token', bind=['kerberos']) body_dict = _build_user_auth() unscoped_token = self.controller.authenticate( - self.context_with_remote_user, body_dict) + self.request_with_remote_user, body_dict) # the token should have bind information in it bind = unscoped_token['access']['token']['bind'] @@ -447,11 +456,11 @@ class AuthWithToken(AuthTest): self.assertRaises( exception.Unauthorized, self.controller.authenticate, - self.empty_context, body_dict) + self.empty_request, body_dict) # using token with remote user context succeeds scoped_token = self.controller.authenticate( - self.context_with_remote_user, body_dict) + self.request_with_remote_user, body_dict) # the bind information should be carried over from the original token bind = scoped_token['access']['token']['bind'] @@ -466,34 +475,31 @@ class AuthWithToken(AuthTest): self.role_api.create_role(role_one['id'], role_one) self.assignment_api.add_role_to_user_and_project( self.user_foo['id'], project1['id'], role_one['id']) - no_context = {} # Get a scoped token for the tenant body_dict = _build_user_auth( username=self.user_foo['name'], password=self.user_foo['password'], tenant_name=project1['name']) - token = self.controller.authenticate(no_context, body_dict) + token = self.controller.authenticate(self.empty_request, body_dict) # Ensure it is valid token_id = token['access']['token']['id'] - self.controller.validate_token( - dict(is_admin=True, query_string={}), - token_id=token_id) + self.controller.validate_token(self.make_request(is_admin=True), + token_id=token_id) # Delete the role, which should invalidate the token - role_controller.delete_role( - dict(is_admin=True, query_string={}), role_one['id']) + role_controller.delete_role(self.make_request(is_admin=True), + role_one['id']) # Check the token is now invalid self.assertRaises( exception.TokenNotFound, self.controller.validate_token, - dict(is_admin=True, query_string={}), + self.make_request(is_admin=True), token_id=token_id) def test_deleting_role_assignment_does_not_revoke_unscoped_token(self): - no_context = {} - admin_context = dict(is_admin=True, query_string={}) + admin_request = self.make_request(is_admin=True) project = unit.new_project_ref( domain_id=CONF.identity.default_domain_id) @@ -504,13 +510,14 @@ class AuthWithToken(AuthTest): self.user_foo['id'], project['id'], role['id']) # Get an unscoped token. - token = self.controller.authenticate(no_context, _build_user_auth( - username=self.user_foo['name'], - password=self.user_foo['password'])) + token = self.controller.authenticate( + self.make_request(), + _build_user_auth(username=self.user_foo['name'], + password=self.user_foo['password'])) token_id = token['access']['token']['id'] # Ensure it is valid - self.controller.validate_token(admin_context, token_id=token_id) + self.controller.validate_token(admin_request, token_id=token_id) # Delete the role assignment, which should not invalidate the token, # because we're not consuming it with just an unscoped token. @@ -518,23 +525,23 @@ class AuthWithToken(AuthTest): self.user_foo['id'], project['id'], role['id']) # Ensure it is still valid - self.controller.validate_token(admin_context, token_id=token_id) + self.controller.validate_token(admin_request, token_id=token_id) def test_only_original_audit_id_is_kept(self): - context = {} - def get_audit_ids(token): return token['access']['token']['audit_ids'] # get a token body_dict = _build_user_auth(username='FOO', password='foo2') - unscoped_token = self.controller.authenticate(context, body_dict) + unscoped_token = self.controller.authenticate(self.make_request(), + body_dict) starting_audit_id = get_audit_ids(unscoped_token)[0] self.assertIsNotNone(starting_audit_id) # get another token to ensure the correct parent audit_id is set body_dict = _build_user_auth(token=unscoped_token["access"]["token"]) - unscoped_token_2 = self.controller.authenticate(context, body_dict) + unscoped_token_2 = self.controller.authenticate(self.make_request(), + body_dict) audit_ids = get_audit_ids(unscoped_token_2) self.assertThat(audit_ids, matchers.HasLength(2)) self.assertThat(audit_ids[-1], matchers.Equals(starting_audit_id)) @@ -542,24 +549,26 @@ class AuthWithToken(AuthTest): # get another token from token 2 and ensure the correct parent # audit_id is set body_dict = _build_user_auth(token=unscoped_token_2["access"]["token"]) - unscoped_token_3 = self.controller.authenticate(context, body_dict) + unscoped_token_3 = self.controller.authenticate(self.make_request(), + body_dict) audit_ids = get_audit_ids(unscoped_token_3) self.assertThat(audit_ids, matchers.HasLength(2)) self.assertThat(audit_ids[-1], matchers.Equals(starting_audit_id)) def test_revoke_by_audit_chain_id_original_token(self): self.config_fixture.config(group='token', revoke_by_id=False) - context = {} # get a token body_dict = _build_user_auth(username='FOO', password='foo2') - unscoped_token = self.controller.authenticate(context, body_dict) + unscoped_token = self.controller.authenticate(self.make_request(), + body_dict) token_id = unscoped_token['access']['token']['id'] self.time_fixture.advance_time_seconds(1) # get a second token body_dict = _build_user_auth(token=unscoped_token["access"]["token"]) - unscoped_token_2 = self.controller.authenticate(context, body_dict) + unscoped_token_2 = self.controller.authenticate(self.make_request(), + body_dict) token_2_id = unscoped_token_2['access']['token']['id'] self.time_fixture.advance_time_seconds(1) @@ -574,17 +583,18 @@ class AuthWithToken(AuthTest): def test_revoke_by_audit_chain_id_chained_token(self): self.config_fixture.config(group='token', revoke_by_id=False) - context = {} # get a token body_dict = _build_user_auth(username='FOO', password='foo2') - unscoped_token = self.controller.authenticate(context, body_dict) + unscoped_token = self.controller.authenticate(self.make_request(), + body_dict) token_id = unscoped_token['access']['token']['id'] self.time_fixture.advance_time_seconds(1) # get a second token body_dict = _build_user_auth(token=unscoped_token["access"]["token"]) - unscoped_token_2 = self.controller.authenticate(context, body_dict) + unscoped_token_2 = self.controller.authenticate(self.make_request(), + body_dict) token_2_id = unscoped_token_2['access']['token']['id'] self.time_fixture.advance_time_seconds(1) @@ -616,7 +626,7 @@ class FernetAuthWithToken(AuthWithToken): body_dict = _build_user_auth() self.assertRaises(exception.NotImplemented, self.controller.authenticate, - self.context_with_remote_user, + self.request_with_remote_user, body_dict) def test_revoke_with_no_audit_info(self): @@ -635,7 +645,7 @@ class AuthWithPasswordCredentials(AuthTest): self.assertRaises( exception.Unauthorized, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_auth_valid_user_invalid_password(self): """Verify exception is raised if invalid password.""" @@ -645,7 +655,7 @@ class AuthWithPasswordCredentials(AuthTest): self.assertRaises( exception.Unauthorized, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_auth_empty_password(self): """Verify exception is raised if empty password.""" @@ -655,7 +665,7 @@ class AuthWithPasswordCredentials(AuthTest): self.assertRaises( exception.Unauthorized, self.controller.authenticate, - {}, body_dict) + self.empty_request, body_dict) def test_auth_no_password(self): """Verify exception is raised if empty password.""" @@ -663,14 +673,14 @@ class AuthWithPasswordCredentials(AuthTest): self.assertRaises( exception.ValidationError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_blank_password_credentials(self): """Sending empty dict as passwordCredentials raises 400 Bad Requset.""" body_dict = {'passwordCredentials': {}, 'tenantName': 'demo'} self.assertRaises(exception.ValidationError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_authenticate_no_username(self): """Verify skipping username raises the right exception.""" @@ -678,13 +688,13 @@ class AuthWithPasswordCredentials(AuthTest): tenant_name="demo") self.assertRaises(exception.ValidationError, self.controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_bind_without_remote_user(self): self.config_fixture.config(group='token', bind=['kerberos']) body_dict = _build_user_auth(username='FOO', password='foo2', tenant_name='BAR') - token = self.controller.authenticate({}, body_dict) + token = self.controller.authenticate(self.make_request(), body_dict) self.assertNotIn('bind', token['access']['token']) def test_change_default_domain_id(self): @@ -715,7 +725,7 @@ class AuthWithPasswordCredentials(AuthTest): password=new_user['password']) # The test is successful if this doesn't raise, so no need to assert. - self.controller.authenticate({}, body_dict) + self.controller.authenticate(self.make_request(), body_dict) class AuthWithRemoteUser(AuthTest): @@ -725,11 +735,11 @@ class AuthWithRemoteUser(AuthTest): username='FOO', password='foo2') local_token = self.controller.authenticate( - {}, body_dict) + self.make_request(), body_dict) body_dict = _build_user_auth() remote_token = self.controller.authenticate( - self.context_with_remote_user, body_dict) + self.request_with_remote_user, body_dict) self.assertEqualTokens(local_token, remote_token, enforce_audit_ids=False) @@ -739,7 +749,7 @@ class AuthWithRemoteUser(AuthTest): self.assertRaises( exception.ValidationError, self.controller.authenticate, - {'REMOTE_USER': 'FOO'}, + self.make_request(environ={'REMOTE_USER': 'FOO'}), None) def test_scoped_remote_authn(self): @@ -749,12 +759,12 @@ class AuthWithRemoteUser(AuthTest): password='foo2', tenant_name='BAR') local_token = self.controller.authenticate( - {}, body_dict) + self.make_request(), body_dict) body_dict = _build_user_auth( tenant_name='BAR') remote_token = self.controller.authenticate( - self.context_with_remote_user, body_dict) + self.request_with_remote_user, body_dict) self.assertEqualTokens(local_token, remote_token, enforce_audit_ids=False) @@ -766,11 +776,11 @@ class AuthWithRemoteUser(AuthTest): password='two2', tenant_name='BAZ') local_token = self.controller.authenticate( - {}, body_dict) + self.make_request(), body_dict) body_dict = _build_user_auth(tenant_name='BAZ') remote_token = self.controller.authenticate( - {'environment': {'REMOTE_USER': 'TWO'}}, body_dict) + self.make_request(environ={'REMOTE_USER': 'TWO'}), body_dict) self.assertEqualTokens(local_token, remote_token, enforce_audit_ids=False) @@ -781,20 +791,20 @@ class AuthWithRemoteUser(AuthTest): self.assertRaises( exception.Unauthorized, self.controller.authenticate, - {'environment': {'REMOTE_USER': uuid.uuid4().hex}}, + self.make_request(environ={'REMOTE_USER': uuid.uuid4().hex}), body_dict) def test_bind_with_kerberos(self): self.config_fixture.config(group='token', bind=['kerberos']) body_dict = _build_user_auth(tenant_name="BAR") - token = self.controller.authenticate(self.context_with_remote_user, + token = self.controller.authenticate(self.request_with_remote_user, body_dict) self.assertEqual('FOO', token['access']['token']['bind']['kerberos']) def test_bind_without_config_opt(self): self.config_fixture.config(group='token', bind=['x509']) body_dict = _build_user_auth(tenant_name='BAR') - token = self.controller.authenticate(self.context_with_remote_user, + token = self.controller.authenticate(self.request_with_remote_user, body_dict) self.assertNotIn('bind', token['access']['token']) @@ -824,7 +834,7 @@ class AuthWithTrust(AuthTest): super(AuthWithTrust, self).config_overrides() self.config_fixture.config(group='trust', enabled=True) - def _create_auth_context(self, token_id): + def _create_auth_request(self, token_id): token_ref = token_model.KeystoneToken( token_id=token_id, token_data=self.token_provider_api.validate_token(token_id)) @@ -834,31 +844,35 @@ class AuthWithTrust(AuthTest): # variables wsgi.url_scheme, SERVER_NAME, SERVER_PORT, and SCRIPT_NAME. # We have to set them in the context so the base url can be constructed # accordingly. - return {'environment': {authorization.AUTH_CONTEXT_ENV: auth_context, - 'wsgi.url_scheme': 'http', - 'SCRIPT_NAME': '/v3', - 'SERVER_PORT': '80', - 'SERVER_NAME': HOST}, - 'token_id': token_id, - 'host_url': HOST_URL} + environ = {authorization.AUTH_CONTEXT_ENV: auth_context, + 'wsgi.url_scheme': 'http', + 'HTTP_HOST': HOST_URL, + 'SCRIPT_NAME': '/v3', + 'SERVER_PORT': '80', + 'SERVER_NAME': HOST} + + req = self.make_request(environ=environ) + req.context_dict['token_id'] = token_id + + return req def create_trust(self, trust_data, trustor_name, expires_at=None, impersonation=True): username = trustor_name password = 'foo2' unscoped_token = self.get_unscoped_token(username, password) - context = self._create_auth_context( + request = self._create_auth_request( unscoped_token['access']['token']['id']) trust_data_copy = copy.deepcopy(trust_data) trust_data_copy['expires_at'] = expires_at trust_data_copy['impersonation'] = impersonation return self.trust_controller.create_trust( - context, trust=trust_data_copy)['trust'] + request, trust=trust_data_copy)['trust'] def get_unscoped_token(self, username, password='foo2'): body_dict = _build_user_auth(username=username, password=password) - return self.controller.authenticate({}, body_dict) + return self.controller.authenticate(self.make_request(), body_dict) def build_v2_token_request(self, username, password, trust, tenant_id=None): @@ -873,7 +887,7 @@ class AuthWithTrust(AuthTest): def test_create_trust_bad_data_fails(self): unscoped_token = self.get_unscoped_token(self.trustor['name']) - context = self._create_auth_context( + request = self._create_auth_request( unscoped_token['access']['token']['id']) bad_sample_data = {'trustor_user_id': self.trustor['id'], 'project_id': self.tenant_bar['id'], @@ -881,15 +895,16 @@ class AuthWithTrust(AuthTest): self.assertRaises(exception.ValidationError, self.trust_controller.create_trust, - context, trust=bad_sample_data) + request, trust=bad_sample_data) def test_create_trust_no_roles(self): unscoped_token = self.get_unscoped_token(self.trustor['name']) - context = {'token_id': unscoped_token['access']['token']['id']} + req = self.make_request() + req.context_dict['token_id'] = unscoped_token['access']['token']['id'] self.sample_data['roles'] = [] self.assertRaises(exception.Forbidden, self.trust_controller.create_trust, - context, trust=self.sample_data) + req, trust=self.sample_data) def test_create_trust(self): expires_at = (timeutils.utcnow() + @@ -932,12 +947,12 @@ class AuthWithTrust(AuthTest): Also, token can be generated with that trust. """ unscoped_token = self.get_unscoped_token(self.trustor['name']) - context = self._create_auth_context( + request = self._create_auth_request( unscoped_token['access']['token']['id']) self.sample_data['project_id'] = None self.sample_data['roles'] = [] new_trust = self.trust_controller.create_trust( - context, trust=self.sample_data)['trust'] + request, trust=self.sample_data)['trust'] self.assertEqual(self.trustor['id'], new_trust['trustor_user_id']) self.assertEqual(self.trustee['id'], new_trust['trustee_user_id']) self.assertIs(new_trust['impersonation'], True) @@ -947,11 +962,11 @@ class AuthWithTrust(AuthTest): def test_get_trust(self): unscoped_token = self.get_unscoped_token(self.trustor['name']) - context = self._create_auth_context( + request = self._create_auth_request( unscoped_token['access']['token']['id']) new_trust = self.trust_controller.create_trust( - context, trust=self.sample_data)['trust'] - trust = self.trust_controller.get_trust(context, + request, trust=self.sample_data)['trust'] + trust = self.trust_controller.get_trust(request, new_trust['id'])['trust'] self.assertEqual(self.trustor['id'], trust['trustor_user_id']) self.assertEqual(self.trustee['id'], trust['trustee_user_id']) @@ -962,14 +977,14 @@ class AuthWithTrust(AuthTest): def test_get_trust_without_auth_context(self): """Verify a trust cannot be retrieved if auth context is missing.""" unscoped_token = self.get_unscoped_token(self.trustor['name']) - context = self._create_auth_context( + request = self._create_auth_request( unscoped_token['access']['token']['id']) new_trust = self.trust_controller.create_trust( - context, trust=self.sample_data)['trust'] + request, trust=self.sample_data)['trust'] # Delete the auth context before calling get_trust(). - del context['environment'][authorization.AUTH_CONTEXT_ENV] + del request.context_dict['environment'][authorization.AUTH_CONTEXT_ENV] self.assertRaises(exception.Forbidden, - self.trust_controller.get_trust, context, + self.trust_controller.get_trust, request, new_trust['id']) def test_create_trust_no_impersonation(self): @@ -995,7 +1010,7 @@ class AuthWithTrust(AuthTest): new_trust = self.create_trust(self.sample_data, self.trustor['name']) request_body = self.build_v2_token_request('FOO', 'foo2', new_trust) self.assertRaises(exception.Forbidden, self.controller.authenticate, - {}, request_body) + self.make_request(), request_body) def test_token_from_trust_wrong_project_fails(self): for assigned_role in self.assigned_roles: @@ -1005,11 +1020,12 @@ class AuthWithTrust(AuthTest): request_body = self.build_v2_token_request('TWO', 'two2', new_trust, self.tenant_baz['id']) self.assertRaises(exception.Forbidden, self.controller.authenticate, - {}, request_body) + self.make_request(), request_body) def fetch_v2_token_from_trust(self, trust): request_body = self.build_v2_token_request('TWO', 'two2', trust) - auth_response = self.controller.authenticate({}, request_body) + auth_response = self.controller.authenticate(self.make_request(), + request_body) return auth_response def fetch_v3_token_from_trust(self, trust, trustee): @@ -1029,10 +1045,8 @@ class AuthWithTrust(AuthTest): } } } - auth_response = (self.auth_v3_controller.authenticate_for_token - ({'environment': {}, - 'query_string': {}}, - v3_password_data)) + auth_response = self.auth_v3_controller.authenticate_for_token( + self.make_request(), v3_password_data) token = auth_response.headers['X-Subject-Token'] v3_req_with_trust = { @@ -1041,10 +1055,8 @@ class AuthWithTrust(AuthTest): "token": {"id": token}}, "scope": { "OS-TRUST:trust": {"id": trust['id']}}} - token_auth_response = (self.auth_v3_controller.authenticate_for_token - ({'environment': {}, - 'query_string': {}}, - v3_req_with_trust)) + token_auth_response = self.auth_v3_controller.authenticate_for_token( + self.make_request(), v3_req_with_trust) return token_auth_response def test_create_v3_token_from_trust(self): @@ -1075,8 +1087,7 @@ class AuthWithTrust(AuthTest): self.assertRaises( exception.Forbidden, self.auth_v3_controller.authenticate_for_token, - {'environment': {}, - 'query_string': {}}, v3_token_data) + self.make_request(), v3_token_data) def test_token_from_trust(self): new_trust = self.create_trust(self.sample_data, self.trustor['name']) @@ -1111,12 +1122,12 @@ class AuthWithTrust(AuthTest): tenant_id=self.tenant_bar['id']) self.assertRaises( exception.Unauthorized, - self.controller.authenticate, {}, request_body) + self.controller.authenticate, self.make_request(), request_body) def test_delete_trust_revokes_token(self): unscoped_token = self.get_unscoped_token(self.trustor['name']) new_trust = self.create_trust(self.sample_data, self.trustor['name']) - context = self._create_auth_context( + request = self._create_auth_request( unscoped_token['access']['token']['id']) self.fetch_v2_token_from_trust(new_trust) trust_id = new_trust['id'] @@ -1124,7 +1135,7 @@ class AuthWithTrust(AuthTest): self.trustor['id'], trust_id=trust_id) self.assertEqual(1, len(tokens)) - self.trust_controller.delete_trust(context, trust_id=trust_id) + self.trust_controller.delete_trust(request, trust_id=trust_id) tokens = self.token_provider_api._persistence._list_tokens( self.trustor['id'], trust_id=trust_id) @@ -1138,7 +1149,7 @@ class AuthWithTrust(AuthTest): request_body = self.build_v2_token_request('TWO', 'two2', new_trust) self.assertRaises( exception.Forbidden, - self.controller.authenticate, {}, request_body) + self.controller.authenticate, self.make_request(), request_body) def test_expired_trust_get_token_fails(self): expires_at = (timeutils.utcnow() + @@ -1152,7 +1163,8 @@ class AuthWithTrust(AuthTest): new_trust) self.assertRaises( exception.Forbidden, - self.controller.authenticate, {}, request_body) + self.controller.authenticate, + self.make_request(), request_body) def test_token_from_trust_with_wrong_role_fails(self): new_trust = self.create_trust(self.sample_data, self.trustor['name']) @@ -1168,7 +1180,7 @@ class AuthWithTrust(AuthTest): self.assertRaises( exception.Forbidden, - self.controller.authenticate, {}, request_body) + self.controller.authenticate, self.make_request(), request_body) def test_do_not_consume_remaining_uses_when_get_token_fails(self): trust_data = copy.deepcopy(self.sample_data) @@ -1181,12 +1193,14 @@ class AuthWithTrust(AuthTest): request_body = self.build_v2_token_request('TWO', 'two2', new_trust) self.assertRaises(exception.Forbidden, - self.controller.authenticate, {}, request_body) + self.controller.authenticate, + self.make_request(), + request_body) unscoped_token = self.get_unscoped_token(self.trustor['name']) - context = self._create_auth_context( + request = self._create_auth_request( unscoped_token['access']['token']['id']) - trust = self.trust_controller.get_trust(context, + trust = self.trust_controller.get_trust(request, new_trust['id'])['trust'] self.assertEqual(3, trust['remaining_uses']) @@ -1202,7 +1216,7 @@ class AuthWithTrust(AuthTest): self.disable_user(self.trustor) self.assertRaises( exception.Forbidden, - self.controller.authenticate, {}, request_body) + self.controller.authenticate, self.make_request(), request_body) def test_trust_get_token_fails_if_trustee_disabled(self): new_trust = self.create_trust(self.sample_data, self.trustor['name']) @@ -1212,7 +1226,7 @@ class AuthWithTrust(AuthTest): self.disable_user(self.trustee) self.assertRaises( exception.Unauthorized, - self.controller.authenticate, {}, request_body) + self.controller.authenticate, self.make_request(), request_body) class TokenExpirationTest(AuthTest): @@ -1224,7 +1238,7 @@ class TokenExpirationTest(AuthTest): mock_utcnow.return_value = now r = self.controller.authenticate( - {}, + self.make_request(), auth={ 'passwordCredentials': { 'username': self.user_foo['name'], @@ -1237,14 +1251,14 @@ class TokenExpirationTest(AuthTest): mock_utcnow.return_value = now + datetime.timedelta(seconds=1) r = self.controller.validate_token( - dict(is_admin=True, query_string={}), + self.make_request(is_admin=True), token_id=unscoped_token_id) self.assertEqual(original_expiration, r['access']['token']['expires']) mock_utcnow.return_value = now + datetime.timedelta(seconds=2) r = self.controller.authenticate( - {}, + self.make_request(), auth={ 'token': { 'id': unscoped_token_id, @@ -1257,7 +1271,7 @@ class TokenExpirationTest(AuthTest): mock_utcnow.return_value = now + datetime.timedelta(seconds=3) r = self.controller.validate_token( - dict(is_admin=True, query_string={}), + self.make_request(is_admin=True), token_id=scoped_token_id) self.assertEqual(original_expiration, r['access']['token']['expires']) @@ -1321,7 +1335,7 @@ class AuthCatalog(unit.SQLDriverOverrides, AuthTest): password='foo2', tenant_name="BAR") - token = self.controller.authenticate({}, body_dict) + token = self.controller.authenticate(self.make_request(), body_dict) # Check the catalog self.assertEqual(1, len(token['access']['serviceCatalog'])) @@ -1347,12 +1361,12 @@ class AuthCatalog(unit.SQLDriverOverrides, AuthTest): password='foo2', tenant_name="BAR") - token = self.controller.authenticate({}, body_dict) + token = self.controller.authenticate(self.make_request(), body_dict) # Validate token_id = token['access']['token']['id'] validate_ref = self.controller.validate_token( - dict(is_admin=True, query_string={}), + self.make_request(is_admin=True), token_id=token_id) # Check the catalog diff --git a/keystone/tests/unit/test_auth_plugin.py b/keystone/tests/unit/test_auth_plugin.py index f0862ed6f..cb6888984 100644 --- a/keystone/tests/unit/test_auth_plugin.py +++ b/keystone/tests/unit/test_auth_plugin.py @@ -73,7 +73,7 @@ class TestAuthPlugin(unit.SQLDriverOverrides, unit.TestCase): auth_info = auth.controllers.AuthInfo.create(None, auth_data) auth_context = {'extras': {}, 'method_names': []} try: - self.api.authenticate({'environment': {}}, auth_info, auth_context) + self.api.authenticate(self.make_request(), auth_info, auth_context) except exception.AdditionalAuthRequired as e: self.assertIn('methods', e.authentication) self.assertIn(METHOD_NAME, e.authentication['methods']) @@ -87,7 +87,7 @@ class TestAuthPlugin(unit.SQLDriverOverrides, unit.TestCase): auth_data = {'identity': auth_data} auth_info = auth.controllers.AuthInfo.create(None, auth_data) auth_context = {'extras': {}, 'method_names': []} - self.api.authenticate({'environment': {}}, auth_info, auth_context) + self.api.authenticate(self.make_request(), auth_info, auth_context) self.assertEqual(DEMO_USER_ID, auth_context['user_id']) # test incorrect response @@ -99,7 +99,7 @@ class TestAuthPlugin(unit.SQLDriverOverrides, unit.TestCase): auth_context = {'extras': {}, 'method_names': []} self.assertRaises(exception.Unauthorized, self.api.authenticate, - {'environment': {}}, + self.make_request(), auth_info, auth_context) @@ -146,18 +146,19 @@ class TestMapped(unit.TestCase): with mock.patch.object(auth.plugins.mapped.Mapped, 'authenticate', return_value=None) as authenticate: - context = {'environment': {}} + request = self.make_request() auth_data = { 'identity': { 'methods': [method_name], method_name: {'protocol': method_name}, } } - auth_info = auth.controllers.AuthInfo.create(context, auth_data) + auth_info = auth.controllers.AuthInfo.create(request.context_dict, + auth_data) auth_context = {'extras': {}, 'method_names': [], 'user_id': uuid.uuid4().hex} - self.api.authenticate(context, auth_info, auth_context) + self.api.authenticate(request, auth_info, auth_context) # make sure Mapped plugin got invoked with the correct payload ((context, auth_payload, auth_context), kwargs) = authenticate.call_args @@ -178,8 +179,8 @@ class TestMapped(unit.TestCase): auth_context = {'extras': {}, 'method_names': [], 'user_id': uuid.uuid4().hex} - environment = {'environment': {'REMOTE_USER': 'foo@idp.com'}} - self.api.authenticate(environment, auth_info, auth_context) + request = self.make_request(environ={'REMOTE_USER': 'foo@idp.com'}) + self.api.authenticate(request, auth_info, auth_context) # make sure Mapped plugin got invoked with the correct payload ((context, auth_payload, auth_context), kwargs) = authenticate.call_args diff --git a/keystone/tests/unit/test_cert_setup.py b/keystone/tests/unit/test_cert_setup.py index 0ac7f045f..895bc060d 100644 --- a/keystone/tests/unit/test_cert_setup.py +++ b/keystone/tests/unit/test_cert_setup.py @@ -76,7 +76,7 @@ class CertSetupTestCase(rest.RestfulTestCase): } self.assertRaises(exception.UnexpectedError, controller.authenticate, - {}, body_dict) + self.make_request(), body_dict) def test_create_pki_certs(self, rebuild=False): pki = openssl.ConfigurePKI(None, None, rebuild=rebuild) diff --git a/keystone/tests/unit/test_v2_controller.py b/keystone/tests/unit/test_v2_controller.py index 28bebd790..e6d169b04 100644 --- a/keystone/tests/unit/test_v2_controller.py +++ b/keystone/tests/unit/test_v2_controller.py @@ -12,8 +12,6 @@ # License for the specific language governing permissions and limitations # under the License. - -import copy import uuid from testtools import matchers @@ -26,9 +24,6 @@ from keystone.tests.unit import default_fixtures from keystone.tests.unit.ksfixtures import database -_ADMIN_CONTEXT = {'is_admin': True, 'query_string': {}} - - class TenantTestCase(unit.TestCase): """Test for the V2 Tenant controller. @@ -57,18 +52,19 @@ class TenantTestCase(unit.TestCase): project_id = self.tenant_bar['id'] orig_project_users = ( - self.assignment_tenant_controller.get_project_users(_ADMIN_CONTEXT, - project_id)) + self.assignment_tenant_controller.get_project_users( + self.make_request(is_admin=True), project_id)) # Assign a role to a user that doesn't exist to the `bar` project. user_id = uuid.uuid4().hex self.assignment_role_controller.add_role_to_user( - _ADMIN_CONTEXT, user_id, self.role_other['id'], project_id) + self.make_request(is_admin=True), user_id, + self.role_other['id'], project_id) new_project_users = ( - self.assignment_tenant_controller.get_project_users(_ADMIN_CONTEXT, - project_id)) + self.assignment_tenant_controller.get_project_users( + self.make_request(is_admin=True), project_id)) # The new user isn't included in the result, so no change. # asserting that the expected values appear in the list, @@ -93,7 +89,8 @@ class TenantTestCase(unit.TestCase): # Now list all projects using the v2 API - we should only get # back those in the default features, since only those are in the # default domain. - refs = self.tenant_controller.get_all_projects(_ADMIN_CONTEXT) + refs = self.tenant_controller.get_all_projects( + self.make_request(is_admin=True)) self.assertEqual(len(default_fixtures.TENANTS), len(refs['tenants'])) for tenant in default_fixtures.TENANTS: tenant_copy = tenant.copy() @@ -111,21 +108,21 @@ class TenantTestCase(unit.TestCase): """Test that get project does not return is_domain projects.""" project = self._create_is_domain_project() - context = copy.deepcopy(_ADMIN_CONTEXT) - context['query_string']['name'] = project['name'] + request = self.make_request(is_admin=True) + request.context_dict['query_string']['name'] = project['name'] self.assertRaises( exception.ProjectNotFound, self.tenant_controller.get_all_projects, - context) + request) - context = copy.deepcopy(_ADMIN_CONTEXT) - context['query_string']['name'] = project['id'] + request = self.make_request(is_admin=True) + request.context_dict['query_string']['name'] = project['id'] self.assertRaises( exception.ProjectNotFound, self.tenant_controller.get_all_projects, - context) + request) def test_create_is_domain_project_fails(self): """Test that the creation of a project acting as a domain fails.""" @@ -135,7 +132,7 @@ class TenantTestCase(unit.TestCase): self.assertRaises( exception.ValidationError, self.tenant_controller.create_project, - _ADMIN_CONTEXT, + self.make_request(is_admin=True), project) def test_create_project_passing_is_domain_false_fails(self): @@ -146,7 +143,7 @@ class TenantTestCase(unit.TestCase): self.assertRaises( exception.ValidationError, self.tenant_controller.create_project, - _ADMIN_CONTEXT, + self.make_request(is_admin=True), project) def test_update_is_domain_project_not_found(self): @@ -157,7 +154,7 @@ class TenantTestCase(unit.TestCase): self.assertRaises( exception.ProjectNotFound, self.tenant_controller.update_project, - _ADMIN_CONTEXT, + self.make_request(is_admin=True), project['id'], project) @@ -168,7 +165,7 @@ class TenantTestCase(unit.TestCase): self.assertRaises( exception.ProjectNotFound, self.tenant_controller.delete_project, - _ADMIN_CONTEXT, + self.make_request(is_admin=True), project['id']) def test_list_is_domain_project_not_found(self): @@ -179,7 +176,8 @@ class TenantTestCase(unit.TestCase): project1 = self._create_is_domain_project() project2 = self._create_is_domain_project() - refs = self.tenant_controller.get_all_projects(_ADMIN_CONTEXT) + refs = self.tenant_controller.get_all_projects( + self.make_request(is_admin=True)) projects = refs.get('tenants') self.assertNotIn(project1, projects) diff --git a/keystone/tests/unit/test_v3.py b/keystone/tests/unit/test_v3.py index 256b4b62f..d39feb718 100644 --- a/keystone/tests/unit/test_v3.py +++ b/keystone/tests/unit/test_v3.py @@ -1301,17 +1301,16 @@ class RestfulTestCase(unit.SQLDriverOverrides, rest.RestfulTestCase, def build_external_auth_request(self, remote_user, remote_domain=None, auth_data=None, kerberos=False): - context = {'environment': {'REMOTE_USER': remote_user, - 'AUTH_TYPE': 'Negotiate'}} + environment = {'REMOTE_USER': remote_user, 'AUTH_TYPE': 'Negotiate'} if remote_domain: - context['environment']['REMOTE_DOMAIN'] = remote_domain + environment['REMOTE_DOMAIN'] = remote_domain if not auth_data: auth_data = self.build_authentication_request( kerberos=kerberos)['auth'] no_context = None auth_info = auth.controllers.AuthInfo.create(no_context, auth_data) auth_context = {'extras': {}, 'method_names': []} - return context, auth_info, auth_context + return self.make_request(environ=environment), auth_info, auth_context class VersionTestCase(RestfulTestCase): diff --git a/keystone/tests/unit/test_v3_auth.py b/keystone/tests/unit/test_v3_auth.py index 4e8b4153f..114bd6c1d 100644 --- a/keystone/tests/unit/test_v3_auth.py +++ b/keystone/tests/unit/test_v3_auth.py @@ -2493,11 +2493,11 @@ class TestAuthExternalDisabled(test_v3.RestfulTestCase): def test_remote_user_disabled(self): api = auth.controllers.Auth() remote_user = '%s@%s' % (self.user['name'], self.domain['name']) - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( remote_user) self.assertRaises(exception.Unauthorized, api.authenticate, - context, + request, auth_info, auth_context) @@ -2514,10 +2514,10 @@ class TestAuthExternalDomain(test_v3.RestfulTestCase): api = auth.controllers.Auth() remote_user = self.user['name'] remote_domain = self.domain['name'] - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( remote_user, remote_domain=remote_domain, kerberos=self.kerberos) - api.authenticate(context, auth_info, auth_context) + api.authenticate(request, auth_info, auth_context) self.assertEqual(self.user['id'], auth_context['user_id']) # Now test to make sure the user name can, itself, contain the @@ -2525,10 +2525,10 @@ class TestAuthExternalDomain(test_v3.RestfulTestCase): user = {'name': 'myname@mydivision'} self.identity_api.update_user(self.user['id'], user) remote_user = user['name'] - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( remote_user, remote_domain=remote_domain, kerberos=self.kerberos) - api.authenticate(context, auth_info, auth_context) + api.authenticate(request, auth_info, auth_context) self.assertEqual(self.user['id'], auth_context['user_id']) def test_project_id_scoped_with_remote_user(self): @@ -2570,10 +2570,10 @@ class TestAuthExternalDefaultDomain(test_v3.RestfulTestCase): def test_remote_user_with_default_domain(self): api = auth.controllers.Auth() remote_user = self.default_domain_user['name'] - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( remote_user, kerberos=self.kerberos) - api.authenticate(context, auth_info, auth_context) + api.authenticate(request, auth_info, auth_context) self.assertEqual(self.default_domain_user['id'], auth_context['user_id']) @@ -2582,10 +2582,10 @@ class TestAuthExternalDefaultDomain(test_v3.RestfulTestCase): user = {'name': 'myname@mydivision'} self.identity_api.update_user(self.default_domain_user['id'], user) remote_user = user['name'] - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( remote_user, kerberos=self.kerberos) - api.authenticate(context, auth_info, auth_context) + api.authenticate(request, auth_info, auth_context) self.assertEqual(self.default_domain_user['id'], auth_context['user_id']) @@ -3204,28 +3204,28 @@ class TestAuth(test_v3.RestfulTestCase): def test_remote_user_no_realm(self): api = auth.controllers.Auth() - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( self.default_domain_user['name']) - api.authenticate(context, auth_info, auth_context) + api.authenticate(request, auth_info, auth_context) self.assertEqual(self.default_domain_user['id'], auth_context['user_id']) # Now test to make sure the user name can, itself, contain the # '@' character. user = {'name': 'myname@mydivision'} self.identity_api.update_user(self.default_domain_user['id'], user) - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( user["name"]) - api.authenticate(context, auth_info, auth_context) + api.authenticate(request, auth_info, auth_context) self.assertEqual(self.default_domain_user['id'], auth_context['user_id']) def test_remote_user_no_domain(self): api = auth.controllers.Auth() - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( self.user['name']) self.assertRaises(exception.Unauthorized, api.authenticate, - context, + request, auth_info, auth_context) @@ -3237,10 +3237,10 @@ class TestAuth(test_v3.RestfulTestCase): user_domain_id=self.default_domain_user['domain_id'], username=self.default_domain_user['name'], password=self.default_domain_user['password'])['auth'] - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( self.default_domain_user['name'], auth_data=auth_data) - api.authenticate(context, auth_info, auth_context) + api.authenticate(request, auth_info, auth_context) def test_remote_user_and_explicit_external(self): # both REMOTE_USER and password methods must pass. @@ -3256,7 +3256,7 @@ class TestAuth(test_v3.RestfulTestCase): auth_context = {'extras': {}, 'method_names': []} self.assertRaises(exception.Unauthorized, api.authenticate, - self.empty_context, + self.make_request(), auth_info, auth_context) @@ -3267,11 +3267,11 @@ class TestAuth(test_v3.RestfulTestCase): user_domain_id=self.domain['id'], username=self.user['name'], password='badpassword')['auth'] - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( self.default_domain_user['name'], auth_data=auth_data) self.assertRaises(exception.Unauthorized, api.authenticate, - context, + request, auth_info, auth_context) @@ -3639,11 +3639,11 @@ class TestAuthJSONExternal(test_v3.RestfulTestCase): def test_remote_user_no_method(self): api = auth.controllers.Auth() - context, auth_info, auth_context = self.build_external_auth_request( + request, auth_info, auth_context = self.build_external_auth_request( self.default_domain_user['name']) self.assertRaises(exception.Unauthorized, api.authenticate, - context, + request, auth_info, auth_context) @@ -4618,11 +4618,10 @@ class TestAPIProtectionWithoutAuthContextMiddleware(test_v3.RestfulTestCase): auth_controller = auth.controllers.Auth() # all we care is that auth context is not in the environment and # 'token_id' is used to build the auth context instead - context = {'subject_token_id': token, - 'token_id': token, - 'query_string': {}, - 'environment': {}} - r = auth_controller.validate_token(context) + request = self.make_request() + request.context_dict['subject_token_id'] = token + request.context_dict['token_id'] = token + r = auth_controller.validate_token(request) self.assertEqual(http_client.OK, r.status_code) diff --git a/keystone/tests/unit/test_v3_federation.py b/keystone/tests/unit/test_v3_federation.py index 56c95115d..2346cffdb 100644 --- a/keystone/tests/unit/test_v3_federation.py +++ b/keystone/tests/unit/test_v3_federation.py @@ -159,11 +159,11 @@ class FederatedSetupMixin(object): assertion='EMPLOYEE_ASSERTION', environment=None): api = federation_controllers.Auth() - context = {'environment': environment or {}} - self._inject_assertion(context, assertion) + request = self.make_request(environ=environment or {}) + self._inject_assertion(request, assertion) if idp is None: idp = self.IDP - r = api.federated_authentication(context, idp, self.PROTOCOL) + r = api.federated_authentication(request, idp, self.PROTOCOL) return r def idp_ref(self, id=None): @@ -206,10 +206,10 @@ class FederatedSetupMixin(object): } } - def _inject_assertion(self, context, variant, query_string=None): + def _inject_assertion(self, request, variant, query_string=None): assertion = getattr(mapping_fixtures, variant) - context['environment'].update(assertion) - context['query_string'] = query_string or [] + request.context_dict['environment'].update(assertion) + request.context_dict['query_string'] = query_string or [] def load_federation_sample_data(self): """Inject additional data.""" @@ -719,15 +719,15 @@ class FederatedSetupMixin(object): self.proto_saml['id'], self.proto_saml) # Generate fake tokens - context = {'environment': {}} + request = self.make_request() self.tokens = {} VARIANTS = ('EMPLOYEE_ASSERTION', 'CUSTOMER_ASSERTION', 'ADMIN_ASSERTION') api = auth_controllers.Auth() for variant in VARIANTS: - self._inject_assertion(context, variant) - r = api.authenticate_for_token(context, self.UNSCOPED_V3_SAML2_REQ) + self._inject_assertion(request, variant) + r = api.authenticate_for_token(request, self.UNSCOPED_V3_SAML2_REQ) self.tokens[variant] = r.headers.get('X-Subject-Token') self.TOKEN_SCOPE_PROJECT_FROM_NONEXISTENT_TOKEN = self._scope_request( @@ -1759,16 +1759,14 @@ class FederatedTokenTests(test_v3.RestfulTestCase, FederatedSetupMixin): """ api = auth_controllers.Auth() - context = { - 'environment': { - 'malformed_object': object(), - 'another_bad_idea': tuple(range(10)), - 'yet_another_bad_param': dict(zip(uuid.uuid4().hex, - range(32))) - } + environ = { + 'malformed_object': object(), + 'another_bad_idea': tuple(range(10)), + 'yet_another_bad_param': dict(zip(uuid.uuid4().hex, range(32))) } - self._inject_assertion(context, 'EMPLOYEE_ASSERTION') - r = api.authenticate_for_token(context, self.UNSCOPED_V3_SAML2_REQ) + request = self.make_request(environ=environ) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION') + r = api.authenticate_for_token(request, self.UNSCOPED_V3_SAML2_REQ) self.assertIsNotNone(r.headers.get('X-Subject-Token')) def test_scope_to_project_once_notify(self): @@ -1858,11 +1856,11 @@ class FederatedTokenTests(test_v3.RestfulTestCase, FederatedSetupMixin): def test_issue_token_from_rules_without_user(self): api = auth_controllers.Auth() - context = {'environment': {}} - self._inject_assertion(context, 'BAD_TESTER_ASSERTION') + request = self.make_request() + self._inject_assertion(request, 'BAD_TESTER_ASSERTION') self.assertRaises(exception.Unauthorized, api.authenticate_for_token, - context, self.UNSCOPED_V3_SAML2_REQ) + request, self.UNSCOPED_V3_SAML2_REQ) def test_issue_token_with_nonexistent_group(self): """Inject assertion that matches rule issuing bad group id. @@ -3547,10 +3545,10 @@ class WebSSOTests(FederatedTokenTests): def test_federated_sso_auth(self): environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} - context = {'environment': environment} + request = self.make_request(environ=environment) query_string = {'origin': self.ORIGIN} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION', query_string) - resp = self.api.federated_sso_auth(context, self.PROTOCOL) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) + resp = self.api.federated_sso_auth(request, self.PROTOCOL) # `resp.body` will be `str` in Python 2 and `bytes` in Python 3 # which is why expected value: `self.TRUSTED_DASHBOARD` # needs to be encoded @@ -3577,10 +3575,10 @@ class WebSSOTests(FederatedTokenTests): remote_id_attribute=self.PROTOCOL_REMOTE_ID_ATTR) environment = {self.PROTOCOL_REMOTE_ID_ATTR: self.REMOTE_IDS[0]} - context = {'environment': environment} + request = self.make_request(environ=environment) query_string = {'origin': self.ORIGIN} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION', query_string) - resp = self.api.federated_sso_auth(context, self.PROTOCOL) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) + resp = self.api.federated_sso_auth(request, self.PROTOCOL) # `resp.body` will be `str` in Python 2 and `bytes` in Python 3 # which is why expected value: `self.TRUSTED_DASHBOARD` # needs to be encoded @@ -3588,61 +3586,61 @@ class WebSSOTests(FederatedTokenTests): def test_federated_sso_auth_bad_remote_id(self): environment = {self.REMOTE_ID_ATTR: self.IDP} - context = {'environment': environment} + request = self.make_request(environ=environment) query_string = {'origin': self.ORIGIN} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION', query_string) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) self.assertRaises(exception.IdentityProviderNotFound, self.api.federated_sso_auth, - context, self.PROTOCOL) + request, self.PROTOCOL) def test_federated_sso_missing_query(self): environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} - context = {'environment': environment} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION') + request = self.make_request(environ=environment) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION') self.assertRaises(exception.ValidationError, self.api.federated_sso_auth, - context, self.PROTOCOL) + request, self.PROTOCOL) def test_federated_sso_missing_query_bad_remote_id(self): environment = {self.REMOTE_ID_ATTR: self.IDP} - context = {'environment': environment} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION') + request = self.make_request(environ=environment) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION') self.assertRaises(exception.ValidationError, self.api.federated_sso_auth, - context, self.PROTOCOL) + request, self.PROTOCOL) def test_federated_sso_untrusted_dashboard(self): environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} - context = {'environment': environment} + request = self.make_request(environ=environment) query_string = {'origin': uuid.uuid4().hex} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION', query_string) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) self.assertRaises(exception.Unauthorized, self.api.federated_sso_auth, - context, self.PROTOCOL) + request, self.PROTOCOL) def test_federated_sso_untrusted_dashboard_bad_remote_id(self): environment = {self.REMOTE_ID_ATTR: self.IDP} - context = {'environment': environment} + request = self.make_request(environ=environment) query_string = {'origin': uuid.uuid4().hex} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION', query_string) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) self.assertRaises(exception.Unauthorized, self.api.federated_sso_auth, - context, self.PROTOCOL) + request, self.PROTOCOL) def test_federated_sso_missing_remote_id(self): - context = {'environment': {}} + request = self.make_request() query_string = {'origin': self.ORIGIN} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION', query_string) + self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) self.assertRaises(exception.Unauthorized, self.api.federated_sso_auth, - context, self.PROTOCOL) + request, self.PROTOCOL) def test_identity_provider_specific_federated_authentication(self): environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} - context = {'environment': environment} + request = self.make_request(environ=environment) query_string = {'origin': self.ORIGIN} - self._inject_assertion(context, 'EMPLOYEE_ASSERTION', query_string) - resp = self.api.federated_idp_specific_sso_auth(context, + self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) + resp = self.api.federated_idp_specific_sso_auth(request, self.idp['id'], self.PROTOCOL) # `resp.body` will be `str` in Python 2 and `bytes` in Python 3 diff --git a/keystone/tests/unit/test_wsgi.py b/keystone/tests/unit/test_wsgi.py index b5060d3b7..ee7010365 100644 --- a/keystone/tests/unit/test_wsgi.py +++ b/keystone/tests/unit/test_wsgi.py @@ -33,13 +33,13 @@ from keystone.tests import unit class FakeApp(wsgi.Application): - def index(self, context): + def index(self, request): return {'a': 'b'} class FakeAttributeCheckerApp(wsgi.Application): - def index(self, context): - return context['query_string'] + def index(self, request): + return request.context_dict['query_string'] def assert_attribute(self, body, attr): """Assert that the given request has a certain attribute.""" @@ -88,16 +88,16 @@ class ApplicationTest(BaseWSGITest): def test_query_string_available(self): class FakeApp(wsgi.Application): - def index(self, context): - return context['query_string'] + def index(self, request): + return request.context_dict['query_string'] req = self._make_request(url='/?1=2') resp = req.get_response(FakeApp()) self.assertEqual({'1': '2'}, jsonutils.loads(resp.body)) def test_headers_available(self): class FakeApp(wsgi.Application): - def index(self, context): - return context['headers'] + def index(self, request): + return request.context_dict['headers'] app = FakeApp() req = self._make_request(url='/?1=2') @@ -221,8 +221,8 @@ class ApplicationTest(BaseWSGITest): def test_improperly_encoded_params(self): class FakeApp(wsgi.Application): - def index(self, context): - return context['query_string'] + def index(self, request): + return request.context_dict['query_string'] # this is high bit set ASCII, copy & pasted from Windows. # aka code page 1252. It is not valid UTF8. req = self._make_request(url='/?name=nonexit%E8nt') @@ -231,8 +231,8 @@ class ApplicationTest(BaseWSGITest): def test_properly_encoded_params(self): class FakeApp(wsgi.Application): - def index(self, context): - return context['query_string'] + def index(self, request): + return request.context_dict['query_string'] # nonexitènt encoded as UTF-8 req = self._make_request(url='/?name=nonexit%C3%A8nt') resp = req.get_response(FakeApp()) @@ -241,8 +241,8 @@ class ApplicationTest(BaseWSGITest): def test_base_url(self): class FakeApp(wsgi.Application): - def index(self, context): - return self.base_url(context, 'public') + def index(self, request): + return self.base_url(request.context_dict, 'public') req = self._make_request(url='/') # NOTE(gyee): according to wsgiref, if HTTP_HOST is present in the # request environment, it will be used to construct the base url. @@ -582,7 +582,7 @@ class LocalizedResponseTest(unit.TestCase): # Fake app raises NotFound exception to simulate Keystone raising. class FakeApp(wsgi.Application): - def index(self, context): + def index(self, request): raise exception.NotFound(target=target) # Make the request with Accept-Language on the app, expect an error diff --git a/keystone/token/controllers.py b/keystone/token/controllers.py index ecffa2964..a045ebd6a 100644 --- a/keystone/token/controllers.py +++ b/keystone/token/controllers.py @@ -48,19 +48,19 @@ class ExternalAuthNotApplicable(Exception): class Auth(controller.V2Controller): @controller.v2_deprecated - def ca_cert(self, context, auth=None): + def ca_cert(self, request, auth=None): with open(CONF.signing.ca_certs, 'r') as ca_file: data = ca_file.read() return data @controller.v2_deprecated - def signing_cert(self, context, auth=None): + def signing_cert(self, request, auth=None): with open(CONF.signing.certfile, 'r') as cert_file: data = cert_file.read() return data @controller.v2_auth_deprecated - def authenticate(self, context, auth=None): + def authenticate(self, request, auth=None): """Authenticate credentials and return a token. Accept auth as a dict that looks like:: @@ -88,16 +88,16 @@ class Auth(controller.V2Controller): if "token" in auth: # Try to authenticate using a token auth_info = self._authenticate_token( - context, auth) + request.context_dict, auth) else: # Try external authentication try: auth_info = self._authenticate_external( - context, auth) + request.context_dict, auth) except ExternalAuthNotApplicable: # Try local authentication auth_info = self._authenticate_local( - context, auth) + request.context_dict, auth) user_ref, tenant_ref, metadata_ref, expiry, bind, audit_id = auth_info # Validate that the auth info is valid and nothing is disabled @@ -421,7 +421,7 @@ class Auth(controller.V2Controller): @controller.v2_deprecated @controller.protected() - def validate_token_head(self, context, token_id): + def validate_token_head(self, request, token_id): """Check that a token is valid. Optionally, also ensure that it is owned by a specific tenant. @@ -432,12 +432,12 @@ class Auth(controller.V2Controller): the content body. """ - belongs_to = context['query_string'].get('belongsTo') + belongs_to = request.context_dict['query_string'].get('belongsTo') return self.token_provider_api.validate_v2_token(token_id, belongs_to) @controller.v2_deprecated @controller.protected() - def validate_token(self, context, token_id): + def validate_token(self, request, token_id): """Check that a token is valid. Optionally, also ensure that it is owned by a specific tenant. @@ -445,20 +445,20 @@ class Auth(controller.V2Controller): Returns metadata about the token along any associated roles. """ - belongs_to = context['query_string'].get('belongsTo') + belongs_to = request.context_dict['query_string'].get('belongsTo') # TODO(ayoung) validate against revocation API return self.token_provider_api.validate_v2_token(token_id, belongs_to) @controller.v2_deprecated - def delete_token(self, context, token_id): + def delete_token(self, request, token_id): """Delete a token, effectively invalidating it for authz.""" # TODO(termie): this stuff should probably be moved to middleware - self.assert_admin(context) + self.assert_admin(request.context_dict) self.token_provider_api.revoke_token(token_id) @controller.v2_deprecated @controller.protected() - def revocation_list(self, context, auth=None): + def revocation_list(self, request, auth=None): if not CONF.token.revoke_by_id: raise exception.Gone() tokens = self.token_provider_api.list_revoked_tokens() @@ -476,9 +476,9 @@ class Auth(controller.V2Controller): return {'signed': signed_text} @controller.v2_deprecated - def endpoints(self, context, token_id): + def endpoints(self, request, token_id): """Return a list of endpoints available to the token.""" - self.assert_admin(context) + self.assert_admin(request.context_dict) token_ref = self._get_token_ref(token_id) diff --git a/keystone/trust/controllers.py b/keystone/trust/controllers.py index 8ba6ec3e4..9fb6ad5c8 100644 --- a/keystone/trust/controllers.py +++ b/keystone/trust/controllers.py @@ -60,13 +60,13 @@ class TrustV3(controller.V3Controller): return None return token_ref.user_id - def get_trust(self, context, trust_id): - user_id = self._get_user_id(context) + def get_trust(self, request, trust_id): + user_id = self._get_user_id(request.context_dict) trust = self.trust_api.get_trust(trust_id) _trustor_trustee_only(trust, user_id) - self._fill_in_roles(context, trust, + self._fill_in_roles(request.context_dict, trust, self.role_api.list_roles()) - return TrustV3.wrap_member(context, trust) + return TrustV3.wrap_member(request.context_dict, trust) def _fill_in_roles(self, context, trust, all_roles): if trust.get('expires_at') is not None: @@ -113,14 +113,14 @@ class TrustV3(controller.V3Controller): @controller.protected() @validation.validated(schema.trust_create, 'trust') - def create_trust(self, context, trust): + def create_trust(self, request, trust): """Create a new trust. The user creating the trust must be the trustor. """ - auth_context = context.get('environment', - {}).get('KEYSTONE_AUTH_CONTEXT', {}) + env = request.context_dict.get('environment', {}) + auth_context = env.get('KEYSTONE_AUTH_CONTEXT', {}) # Check if delegated via trust if auth_context.get('is_delegated_auth'): @@ -136,7 +136,7 @@ class TrustV3(controller.V3Controller): if trust.get('project_id'): self._require_role(trust) - self._require_user_is_trustor(context, trust) + self._require_user_is_trustor(request.context_dict, trust) self._require_trustee_exists(trust['trustee_user_id']) all_roles = self.role_api.list_roles() # Normalize roles @@ -146,13 +146,13 @@ class TrustV3(controller.V3Controller): trust['expires_at'] = self._parse_expiration_date( trust.get('expires_at')) trust_id = uuid.uuid4().hex - initiator = notifications._get_request_audit_info(context) + initiator = notifications._get_request_audit_info(request.context_dict) new_trust = self.trust_api.create_trust(trust_id, trust, normalized_roles, redelegated_trust, initiator) - self._fill_in_roles(context, new_trust, all_roles) - return TrustV3.wrap_member(context, new_trust) + self._fill_in_roles(request.context_dict, new_trust, all_roles) + return TrustV3.wrap_member(request.context_dict, new_trust) def _require_trustee_exists(self, trustee_user_id): self.identity_api.get_user(trustee_user_id) @@ -215,22 +215,22 @@ class TrustV3(controller.V3Controller): raise exception.RoleNotFound(role_id=role_id) @controller.protected() - def list_trusts(self, context): - query = context['query_string'] + def list_trusts(self, request): + query = request.context_dict['query_string'] trusts = [] if not query: - self.assert_admin(context) + self.assert_admin(request.context_dict) trusts += self.trust_api.list_trusts() if 'trustor_user_id' in query: user_id = query['trustor_user_id'] - calling_user_id = self._get_user_id(context) + calling_user_id = self._get_user_id(request.context_dict) if user_id != calling_user_id: raise exception.Forbidden() trusts += (self.trust_api. list_trusts_for_trustor(user_id)) if 'trustee_user_id' in query: user_id = query['trustee_user_id'] - calling_user_id = self._get_user_id(context) + calling_user_id = self._get_user_id(request.context_dict) if user_id != calling_user_id: raise exception.Forbidden() trusts += self.trust_api.list_trusts_for_trustee(user_id) @@ -244,27 +244,28 @@ class TrustV3(controller.V3Controller): trust['expires_at'] = (utils.isotime (trust['expires_at'], subsecond=True)) - return TrustV3.wrap_collection(context, trusts) + return TrustV3.wrap_collection(request.context_dict, trusts) @controller.protected() - def delete_trust(self, context, trust_id): + def delete_trust(self, request, trust_id): trust = self.trust_api.get_trust(trust_id) - user_id = self._get_user_id(context) - _admin_trustor_only(context, trust, user_id) - initiator = notifications._get_request_audit_info(context) + user_id = self._get_user_id(request.context_dict) + _admin_trustor_only(request.context_dict, trust, user_id) + initiator = notifications._get_request_audit_info(request.context_dict) self.trust_api.delete_trust(trust_id, initiator) @controller.protected() - def list_roles_for_trust(self, context, trust_id): - trust = self.get_trust(context, trust_id)['trust'] - user_id = self._get_user_id(context) + def list_roles_for_trust(self, request, trust_id): + trust = self.get_trust(request, trust_id)['trust'] + user_id = self._get_user_id(request.context_dict) _trustor_trustee_only(trust, user_id) return {'roles': trust['roles'], 'links': trust['roles_links']} @controller.protected() - def get_role_for_trust(self, context, trust_id, role_id): + def get_role_for_trust(self, request, trust_id, role_id): """Get a role that has been assigned to a trust.""" - self._check_role_for_trust(context, trust_id, role_id) + self._check_role_for_trust(request.context_dict, trust_id, role_id) role = self.role_api.get_role(role_id) - return assignment.controllers.RoleV3.wrap_member(context, role) + return assignment.controllers.RoleV3.wrap_member(request.context_dict, + role) diff --git a/keystone/v2_crud/user_crud.py b/keystone/v2_crud/user_crud.py index 9da7f31f3..6f8159bd7 100644 --- a/keystone/v2_crud/user_crud.py +++ b/keystone/v2_crud/user_crud.py @@ -12,7 +12,6 @@ # License for the specific language governing permissions and limitations # under the License. -import copy import uuid from oslo_log import log @@ -50,8 +49,8 @@ extension.register_public_extension( @dependency.requires('catalog_api', 'identity_api', 'resource_api', 'token_provider_api') class UserController(identity.controllers.User): - def set_user_password(self, context, user_id, user): - token_id = context.get('token_id') + def set_user_password(self, request, user_id, user): + token_id = request.context_dict.get('token_id') original_password = user.get('original_password') token_data = self.token_provider_api.validate_token(token_id) @@ -66,7 +65,7 @@ class UserController(identity.controllers.User): try: user_ref = self.identity_api.authenticate( - context, + request.context_dict, user_id=token_ref.user_id, password=original_password) if not user_ref.get('enabled', True): @@ -77,12 +76,15 @@ class UserController(identity.controllers.User): update_dict = {'password': user['password'], 'id': user_id} - admin_context = copy.copy(context) - admin_context['is_admin'] = True - super(UserController, self).set_user_password(admin_context, + old_admin = request.context_dict.pop('is_admin', False) + request.context_dict['is_admin'] = True + + super(UserController, self).set_user_password(request, user_id, update_dict) + request.context_dict['is_admin'] = old_admin + # Issue a new token based upon the original token data. This will # always be a V2.0 token. diff --git a/keystone/version/controllers.py b/keystone/version/controllers.py index 50f2ae4a6..f3230c5cd 100644 --- a/keystone/version/controllers.py +++ b/keystone/version/controllers.py @@ -56,10 +56,10 @@ class Extensions(wsgi.Application): def extensions(self): return None - def get_extensions_info(self, context): + def get_extensions_info(self, request): return {'extensions': {'values': list(self.extensions.values())}} - def get_extension_info(self, context, extension_alias): + def get_extension_info(self, request, extension_alias): try: return {'extension': self.extensions[extension_alias]} except KeyError: @@ -159,24 +159,24 @@ class Version(wsgi.Application): return versions - def get_versions(self, context): + def get_versions(self, request): - req_mime_type = v3_mime_type_best_match(context) + req_mime_type = v3_mime_type_best_match(request.context_dict) if req_mime_type == MimeTypes.JSON_HOME: v3_json_home = request_v3_json_home('/v3') return wsgi.render_response( body=v3_json_home, headers=(('Content-Type', MimeTypes.JSON_HOME),)) - versions = self._get_versions_list(context) + versions = self._get_versions_list(request.context_dict) return wsgi.render_response(status=(300, 'Multiple Choices'), body={ 'versions': { 'values': list(versions.values()) } }) - def get_version_v2(self, context): - versions = self._get_versions_list(context) + def get_version_v2(self, request): + versions = self._get_versions_list(request.context_dict) if 'v2.0' in _VERSIONS: return wsgi.render_response(body={ 'version': versions['v2.0'] @@ -195,10 +195,10 @@ class Version(wsgi.Application): 'resources': dict(all_resources()) } - def get_version_v3(self, context): - versions = self._get_versions_list(context) + def get_version_v3(self, request): + versions = self._get_versions_list(request.context_dict) if 'v3' in _VERSIONS: - req_mime_type = v3_mime_type_best_match(context) + req_mime_type = v3_mime_type_best_match(request.context_dict) if req_mime_type == MimeTypes.JSON_HOME: return wsgi.render_response(