Merge "Use request.params instead of context['query_string']"

This commit is contained in:
Jenkins 2016-06-29 02:48:52 +00:00 committed by Gerrit Code Review
commit 5583cd1df0
15 changed files with 97 additions and 103 deletions

View File

@ -58,8 +58,8 @@ class TenantAssignment(controller.V2Controller):
tenant_refs = [self.v3_to_v2_project(ref) for ref in tenant_refs tenant_refs = [self.v3_to_v2_project(ref) for ref in tenant_refs
if ref['domain_id'] == CONF.identity.default_domain_id] if ref['domain_id'] == CONF.identity.default_domain_id]
params = { params = {
'limit': request.context_dict['query_string'].get('limit'), 'limit': request.params.get('limit'),
'marker': request.context_dict['query_string'].get('marker'), 'marker': request.params.get('marker'),
} }
return self.format_project_list(tenant_refs, **params) return self.format_project_list(tenant_refs, **params)
@ -333,6 +333,8 @@ class RoleV3(controller.V3Controller):
def list_roles_wrapper(self, request): def list_roles_wrapper(self, request):
# If there is no domain_id filter defined, then we only want to return # If there is no domain_id filter defined, then we only want to return
# global roles, so we set the domain_id filter to None. # global roles, so we set the domain_id filter to None.
# NOTE(jamielennox): this is still using context_dict because it's
# writing to the query dict. Why is it writing to the query dict?
params = request.context_dict['query_string'] params = request.context_dict['query_string']
if 'domain_id' not in params: if 'domain_id' not in params:
request.context_dict['query_string']['domain_id'] = None request.context_dict['query_string']['domain_id'] = None
@ -871,7 +873,7 @@ class RoleAssignmentV3(controller.V3Controller):
msg = _('Specify a user or group, not both') msg = _('Specify a user or group, not both')
raise exception.ValidationError(msg) raise exception.ValidationError(msg)
def _list_role_assignments(self, context, filters, include_subtree=False): def _list_role_assignments(self, request, filters, include_subtree=False):
"""List role assignments to user and groups on domains and projects. """List role assignments to user and groups on domains and projects.
Return a list of all existing role assignments in the system, filtered Return a list of all existing role assignments in the system, filtered
@ -894,7 +896,7 @@ class RoleAssignmentV3(controller.V3Controller):
both user and group ids or domain and project ids is invalid as well. both user and group ids or domain and project ids is invalid as well.
""" """
params = context['query_string'] params = request.params
effective = 'effective' in params and ( effective = 'effective' in params and (
self.query_filter_is_true(params['effective'])) self.query_filter_is_true(params['effective']))
include_names = ('include_names' in params and include_names = ('include_names' in params and
@ -928,15 +930,16 @@ class RoleAssignmentV3(controller.V3Controller):
inherited=inherited, effective=effective, inherited=inherited, effective=effective,
include_names=include_names) include_names=include_names)
formatted_refs = [self._format_entity(context, ref) for ref in refs] formatted_refs = [self._format_entity(request.context_dict, ref)
for ref in refs]
return self.wrap_collection(context, formatted_refs) return self.wrap_collection(request.context_dict, formatted_refs)
@controller.filterprotected('group.id', 'role.id', @controller.filterprotected('group.id', 'role.id',
'scope.domain.id', 'scope.project.id', 'scope.domain.id', 'scope.project.id',
'scope.OS-INHERIT:inherited_to', 'user.id') 'scope.OS-INHERIT:inherited_to', 'user.id')
def list_role_assignments(self, request, filters): def list_role_assignments(self, request, filters):
return self._list_role_assignments(request.context_dict, filters) return self._list_role_assignments(request, filters)
def _check_list_tree_protection(self, context, protection_info): def _check_list_tree_protection(self, context, protection_info):
"""Check protection for list assignment for tree API. """Check protection for list assignment for tree API.
@ -958,11 +961,11 @@ class RoleAssignmentV3(controller.V3Controller):
'scope.OS-INHERIT:inherited_to', 'user.id', 'scope.OS-INHERIT:inherited_to', 'user.id',
callback=_check_list_tree_protection) callback=_check_list_tree_protection)
def list_role_assignments_for_tree(self, request, filters): def list_role_assignments_for_tree(self, request, filters):
if not request.context_dict['query_string'].get('scope.project.id'): if not request.params.get('scope.project.id'):
msg = _('scope.project.id must be specified if include_subtree ' msg = _('scope.project.id must be specified if include_subtree '
'is also specified') 'is also specified')
raise exception.ValidationError(message=msg) raise exception.ValidationError(message=msg)
return self._list_role_assignments(request.context_dict, filters, return self._list_role_assignments(request, filters,
include_subtree=True) include_subtree=True)
def list_role_assignments_wrapper(self, request): def list_role_assignments_wrapper(self, request):
@ -974,7 +977,7 @@ class RoleAssignmentV3(controller.V3Controller):
protected entry point. protected entry point.
""" """
params = request.context_dict['query_string'] params = request.params
if 'include_subtree' in params and ( if 'include_subtree' in params and (
self.query_filter_is_true(params['include_subtree'])): self.query_filter_is_true(params['include_subtree'])):
return self.list_role_assignments_for_tree(request) return self.list_role_assignments_for_tree(request)

View File

@ -392,8 +392,7 @@ class Auth(controller.V3Controller):
def authenticate_for_token(self, request, auth=None): def authenticate_for_token(self, request, auth=None):
"""Authenticate user and issue a token.""" """Authenticate user and issue a token."""
query_string = request.context_dict['query_string'] include_catalog = 'nocatalog' not in request.params
include_catalog = 'nocatalog' not in query_string
try: try:
auth_info = AuthInfo.create(request.context_dict, auth=auth) auth_info = AuthInfo.create(request.context_dict, auth=auth)
@ -558,8 +557,7 @@ class Auth(controller.V3Controller):
@controller.protected() @controller.protected()
def validate_token(self, request): def validate_token(self, request):
token_id = request.context_dict.get('subject_token_id') token_id = request.context_dict.get('subject_token_id')
query_string = request.context_dict['query_string'] include_catalog = 'nocatalog' not in request.params
include_catalog = 'nocatalog' not in query_string
token_data = self.token_provider_api.validate_v3_token( token_data = self.token_provider_api.validate_v3_token(
token_id) token_id)
if not include_catalog and 'catalog' in token_data['token']: if not include_catalog and 'catalog' in token_data['token']:
@ -571,8 +569,7 @@ class Auth(controller.V3Controller):
if not CONF.token.revoke_by_id: if not CONF.token.revoke_by_id:
raise exception.Gone() raise exception.Gone()
query_string = request.context_dict['query_string'] audit_id_only = 'audit_id_only' in request.params
audit_id_only = 'audit_id_only' in query_string
tokens = self.token_provider_api.list_revoked_tokens() tokens = self.token_provider_api.list_revoked_tokens()

View File

@ -52,7 +52,7 @@ class OAuth(auth.AuthMethodHandler):
result, request = access_verifier.validate_protected_resource_request( result, request = access_verifier.validate_protected_resource_request(
url, url,
http_method='POST', http_method='POST',
body=request.context_dict['query_string'], body=request.params,
headers=request.headers, headers=request.headers,
realms=None realms=None
) )

View File

@ -720,7 +720,7 @@ class V3Controller(wsgi.Application):
ref['id'] = uuid.uuid4().hex ref['id'] = uuid.uuid4().hex
return ref return ref
def _get_domain_id_for_list_request(self, context): def _get_domain_id_for_list_request(self, request):
"""Get the domain_id for a v3 list call. """Get the domain_id for a v3 list call.
If we running with multiple domain drivers, then the caller must If we running with multiple domain drivers, then the caller must
@ -731,10 +731,11 @@ class V3Controller(wsgi.Application):
# We don't need to specify a domain ID in this case # We don't need to specify a domain ID in this case
return return
if context['query_string'].get('domain_id') is not None: domain_id = request.params.get('domain_id')
return context['query_string'].get('domain_id') if domain_id:
return domain_id
token_ref = utils.get_token_ref(context) token_ref = utils.get_token_ref(request.context_dict)
if token_ref.domain_scoped: if token_ref.domain_scoped:
return token_ref.domain_id return token_ref.domain_id

View File

@ -250,7 +250,7 @@ class MappingController(_ControllerBase):
@dependency.requires('federation_api') @dependency.requires('federation_api')
class Auth(auth_controllers.Auth): class Auth(auth_controllers.Auth):
def _get_sso_origin_host(self, context): def _get_sso_origin_host(self, request):
"""Validate and return originating dashboard URL. """Validate and return originating dashboard URL.
Make sure the parameter is specified in the request's URL as well its Make sure the parameter is specified in the request's URL as well its
@ -264,14 +264,15 @@ class Auth(auth_controllers.Auth):
:returns: URL with the originating dashboard :returns: URL with the originating dashboard
""" """
if 'origin' in context['query_string']: origin = request.params.get('origin')
origin = context['query_string']['origin']
host = urllib.parse.unquote_plus(origin) if not origin:
else:
msg = _('Request must have an origin query parameter') msg = _('Request must have an origin query parameter')
LOG.error(msg) LOG.error(msg)
raise exception.ValidationError(msg) raise exception.ValidationError(msg)
host = urllib.parse.unquote_plus(origin)
# change trusted_dashboard hostnames to lowercase before comparison # change trusted_dashboard hostnames to lowercase before comparison
trusted_dashboards = [k_utils.lower_case_hostname(trusted) trusted_dashboards = [k_utils.lower_case_hostname(trusted)
for trusted in CONF.federation.trusted_dashboard] for trusted in CONF.federation.trusted_dashboard]
@ -312,7 +313,7 @@ class Auth(auth_controllers.Auth):
LOG.error(msg) LOG.error(msg)
raise exception.Unauthorized(msg) raise exception.Unauthorized(msg)
host = self._get_sso_origin_host(request.context_dict) host = self._get_sso_origin_host(request)
ref = self.federation_api.get_idp_from_remote_id(remote_id) ref = self.federation_api.get_idp_from_remote_id(remote_id)
# NOTE(stevemar): the returned object is a simple dict that # NOTE(stevemar): the returned object is a simple dict that
@ -325,7 +326,7 @@ class Auth(auth_controllers.Auth):
return self.render_html_response(host, token_id) return self.render_html_response(host, token_id)
def federated_idp_specific_sso_auth(self, request, idp_id, protocol_id): def federated_idp_specific_sso_auth(self, request, idp_id, protocol_id):
host = self._get_sso_origin_host(request.context_dict) host = self._get_sso_origin_host(request)
# NOTE(lbragstad): We validate that the Identity Provider actually # NOTE(lbragstad): We validate that the Identity Provider actually
# exists in the Mapped authentication plugin. # exists in the Mapped authentication plugin.

View File

@ -43,10 +43,8 @@ class User(controller.V2Controller):
def get_users(self, request): def get_users(self, request):
# NOTE(termie): i can't imagine that this really wants all the data # NOTE(termie): i can't imagine that this really wants all the data
# about every single user in the system... # about every single user in the system...
if 'name' in request.context_dict['query_string']: if 'name' in request.params:
return self.get_user_by_name( return self.get_user_by_name(request, request.params['name'])
request,
request.context_dict['query_string'].get('name'))
self.assert_admin(request.context_dict) self.assert_admin(request.context_dict)
user_list = self.identity_api.list_users( user_list = self.identity_api.list_users(
@ -230,7 +228,7 @@ class UserV3(controller.V3Controller):
@controller.filterprotected('domain_id', 'enabled', 'name') @controller.filterprotected('domain_id', 'enabled', 'name')
def list_users(self, request, filters): def list_users(self, request, filters):
hints = UserV3.build_driver_hints(request.context_dict, filters) hints = UserV3.build_driver_hints(request.context_dict, filters)
domain = self._get_domain_id_for_list_request(request.context_dict) domain = self._get_domain_id_for_list_request(request)
refs = self.identity_api.list_users(domain_scope=domain, hints=hints) refs = self.identity_api.list_users(domain_scope=domain, hints=hints)
return UserV3.wrap_collection(request.context_dict, refs, hints=hints) return UserV3.wrap_collection(request.context_dict, refs, hints=hints)
@ -323,7 +321,7 @@ class GroupV3(controller.V3Controller):
@controller.filterprotected('domain_id', 'name') @controller.filterprotected('domain_id', 'name')
def list_groups(self, request, filters): def list_groups(self, request, filters):
hints = GroupV3.build_driver_hints(request.context_dict, filters) hints = GroupV3.build_driver_hints(request.context_dict, filters)
domain = self._get_domain_id_for_list_request(request.context_dict) domain = self._get_domain_id_for_list_request(request)
refs = self.identity_api.list_groups(domain_scope=domain, hints=hints) refs = self.identity_api.list_groups(domain_scope=domain, hints=hints)
return GroupV3.wrap_collection(request.context_dict, refs, hints=hints) return GroupV3.wrap_collection(request.context_dict, refs, hints=hints)

View File

@ -240,7 +240,7 @@ class OAuthControllerV3(controller.V3Controller):
h, b, s = request_verifier.create_request_token_response( h, b, s = request_verifier.create_request_token_response(
url, url,
http_method='POST', http_method='POST',
body=request.context_dict['query_string'], body=request.params,
headers=req_headers) headers=req_headers)
if (not b) or int(s) > 399: if (not b) or int(s) > 399:
@ -305,7 +305,7 @@ class OAuthControllerV3(controller.V3Controller):
h, b, s = access_verifier.create_access_token_response( h, b, s = access_verifier.create_access_token_response(
url, url,
http_method='POST', http_method='POST',
body=request.context_dict['query_string'], body=request.params,
headers=headers) headers=headers)
params = oauth1.extract_non_oauth_params(b) params = oauth1.extract_non_oauth_params(b)
if params: if params:

View File

@ -40,7 +40,7 @@ class Tenant(controller.V2Controller):
"""Get a list of all tenants for an admin user.""" """Get a list of all tenants for an admin user."""
self.assert_admin(request.context_dict) self.assert_admin(request.context_dict)
name = request.context_dict['query_string'].get('name') name = request.params.get('name')
if name: if name:
return self._get_project_by_name(name) return self._get_project_by_name(name)
@ -55,8 +55,8 @@ class Tenant(controller.V2Controller):
for tenant_ref in tenant_refs for tenant_ref in tenant_refs
if not tenant_ref.get('is_domain')] if not tenant_ref.get('is_domain')]
params = { params = {
'limit': request.context_dict['query_string'].get('limit'), 'limit': request.params.get('limit'),
'marker': request.context_dict['query_string'].get('marker'), 'marker': request.params.get('marker'),
} }
return self.format_project_list(tenant_refs, **params) return self.format_project_list(tenant_refs, **params)
@ -263,14 +263,15 @@ class ProjectV3(controller.V3Controller):
hints = ProjectV3.build_driver_hints(request.context_dict, filters) hints = ProjectV3.build_driver_hints(request.context_dict, filters)
# If 'is_domain' has not been included as a query, we default it to # If 'is_domain' has not been included as a query, we default it to
# False (which in query terms means '0' # False (which in query terms means '0'
if 'is_domain' not in request.context_dict['query_string']: if 'is_domain' not in request.params:
hints.add_filter('is_domain', '0') hints.add_filter('is_domain', '0')
refs = self.resource_api.list_projects(hints=hints) refs = self.resource_api.list_projects(hints=hints)
return ProjectV3.wrap_collection(request.context_dict, return ProjectV3.wrap_collection(request.context_dict,
refs, hints=hints) refs, hints=hints)
def _expand_project_ref(self, context, ref): def _expand_project_ref(self, request, ref):
params = context['query_string'] params = request.params
context = request.context_dict
parents_as_list = 'parents_as_list' in params and ( parents_as_list = 'parents_as_list' in params and (
self.query_filter_is_true(params['parents_as_list'])) self.query_filter_is_true(params['parents_as_list']))
@ -316,7 +317,7 @@ class ProjectV3(controller.V3Controller):
@controller.protected() @controller.protected()
def get_project(self, request, project_id): def get_project(self, request, project_id):
ref = self.resource_api.get_project(project_id) ref = self.resource_api.get_project(project_id)
self._expand_project_ref(request.context_dict, ref) self._expand_project_ref(request, ref)
return ProjectV3.wrap_member(request.context_dict, ref) return ProjectV3.wrap_member(request.context_dict, ref)
@controller.protected() @controller.protected()

View File

@ -22,7 +22,7 @@ from keystone.i18n import _
class RevokeController(controller.V3Controller): class RevokeController(controller.V3Controller):
@controller.protected() @controller.protected()
def list_revoke_events(self, request): def list_revoke_events(self, request):
since = request.context_dict['query_string'].get('since') since = request.params.get('since')
last_fetch = None last_fetch = None
if since: if since:
try: try:

View File

@ -574,11 +574,10 @@ class TestCase(BaseTestCase):
def make_request(self, path='/', **kwargs): def make_request(self, path='/', **kwargs):
context = {} context = {}
for k in ('is_admin', 'query_string'): try:
try: context['is_admin'] = kwargs.pop('is_admin')
context[k] = kwargs.pop(k) except KeyError:
except KeyError: pass
pass
req = request.Request.blank(path=path, **kwargs) req = request.Request.blank(path=path, **kwargs)
req.context_dict.update(context) req.context_dict.update(context)

View File

@ -410,8 +410,7 @@ class AuthWithToken(AuthTest):
self.assertRaises( self.assertRaises(
exception.Unauthorized, exception.Unauthorized,
self.controller.validate_token, self.controller.validate_token,
self.make_request(is_admin=True, self.make_request(is_admin=True, query_string='belongsTo=BAR'),
query_string={'belongsTo': 'BAR'}),
token_id=unscoped_token_id) token_id=unscoped_token_id)
def test_belongs_to(self): def test_belongs_to(self):
@ -427,14 +426,13 @@ class AuthWithToken(AuthTest):
self.assertRaises( self.assertRaises(
exception.Unauthorized, exception.Unauthorized,
self.controller.validate_token, self.controller.validate_token,
self.make_request(is_admin=True, query_string={'belongsTo': 'me'}), self.make_request(is_admin=True, query_string='belongsTo=me'),
token_id=scoped_token_id) token_id=scoped_token_id)
self.assertRaises( self.assertRaises(
exception.Unauthorized, exception.Unauthorized,
self.controller.validate_token, self.controller.validate_token,
self.make_request(is_admin=True, self.make_request(is_admin=True, query_string='belongsTo=BAR'),
query_string={'belongsTo': 'BAR'}),
token_id=scoped_token_id) token_id=scoped_token_id)
def test_token_auth_with_binding(self): def test_token_auth_with_binding(self):

View File

@ -108,16 +108,16 @@ class TenantTestCase(unit.TestCase):
"""Test that get project does not return is_domain projects.""" """Test that get project does not return is_domain projects."""
project = self._create_is_domain_project() project = self._create_is_domain_project()
request = self.make_request(is_admin=True) request = self.make_request(is_admin=True,
request.context_dict['query_string']['name'] = project['name'] query_string='name=%s' % project['name'])
self.assertRaises( self.assertRaises(
exception.ProjectNotFound, exception.ProjectNotFound,
self.tenant_controller.get_all_projects, self.tenant_controller.get_all_projects,
request) request)
request = self.make_request(is_admin=True) request = self.make_request(is_admin=True,
request.context_dict['query_string']['name'] = project['id'] query_string='name=%s' % project['id'])
self.assertRaises( self.assertRaises(
exception.ProjectNotFound, exception.ProjectNotFound,

View File

@ -159,8 +159,9 @@ class FederatedSetupMixin(object):
assertion='EMPLOYEE_ASSERTION', assertion='EMPLOYEE_ASSERTION',
environment=None): environment=None):
api = federation_controllers.Auth() api = federation_controllers.Auth()
request = self.make_request(environ=environment or {}) environment = environment or {}
self._inject_assertion(request, assertion) environment.update(getattr(mapping_fixtures, assertion))
request = self.make_request(environ=environment)
if idp is None: if idp is None:
idp = self.IDP idp = self.IDP
r = api.federated_authentication(request, idp, self.PROTOCOL) r = api.federated_authentication(request, idp, self.PROTOCOL)
@ -206,10 +207,9 @@ class FederatedSetupMixin(object):
} }
} }
def _inject_assertion(self, request, variant, query_string=None): def _inject_assertion(self, request, variant):
assertion = getattr(mapping_fixtures, variant) assertion = getattr(mapping_fixtures, variant)
request.context_dict['environment'].update(assertion) request.context_dict['environment'].update(assertion)
request.context_dict['query_string'] = query_string or []
def load_federation_sample_data(self): def load_federation_sample_data(self):
"""Inject additional data.""" """Inject additional data."""
@ -1764,8 +1764,8 @@ class FederatedTokenTests(test_v3.RestfulTestCase, FederatedSetupMixin):
'another_bad_idea': tuple(range(10)), 'another_bad_idea': tuple(range(10)),
'yet_another_bad_param': dict(zip(uuid.uuid4().hex, range(32))) 'yet_another_bad_param': dict(zip(uuid.uuid4().hex, range(32)))
} }
environ.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environ) request = self.make_request(environ=environ)
self._inject_assertion(request, 'EMPLOYEE_ASSERTION')
r = api.authenticate_for_token(request, self.UNSCOPED_V3_SAML2_REQ) r = api.authenticate_for_token(request, self.UNSCOPED_V3_SAML2_REQ)
self.assertIsNotNone(r.headers.get('X-Subject-Token')) self.assertIsNotNone(r.headers.get('X-Subject-Token'))
@ -1856,8 +1856,8 @@ class FederatedTokenTests(test_v3.RestfulTestCase, FederatedSetupMixin):
def test_issue_token_from_rules_without_user(self): def test_issue_token_from_rules_without_user(self):
api = auth_controllers.Auth() api = auth_controllers.Auth()
request = self.make_request() environ = copy.deepcopy(mapping_fixtures.BAD_TESTER_ASSERTION)
self._inject_assertion(request, 'BAD_TESTER_ASSERTION') request = self.make_request(environ=environ)
self.assertRaises(exception.Unauthorized, self.assertRaises(exception.Unauthorized,
api.authenticate_for_token, api.authenticate_for_token,
request, self.UNSCOPED_V3_SAML2_REQ) request, self.UNSCOPED_V3_SAML2_REQ)
@ -3643,10 +3643,10 @@ class WebSSOTests(FederatedTokenTests):
self.assertIn(self.TRUSTED_DASHBOARD.encode('utf-8'), resp.body) self.assertIn(self.TRUSTED_DASHBOARD.encode('utf-8'), resp.body)
def test_federated_sso_auth(self): def test_federated_sso_auth(self):
environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0],
'QUERY_STRING': 'origin=%s' % self.ORIGIN}
environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environment) request = self.make_request(environ=environment)
query_string = {'origin': self.ORIGIN}
self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string)
resp = self.api.federated_sso_auth(request, self.PROTOCOL) resp = self.api.federated_sso_auth(request, self.PROTOCOL)
# `resp.body` will be `str` in Python 2 and `bytes` in Python 3 # `resp.body` will be `str` in Python 2 and `bytes` in Python 3
# which is why expected value: `self.TRUSTED_DASHBOARD` # which is why expected value: `self.TRUSTED_DASHBOARD`
@ -3655,17 +3655,14 @@ class WebSSOTests(FederatedTokenTests):
def test_get_sso_origin_host_case_insensitive(self): def test_get_sso_origin_host_case_insensitive(self):
# test lowercase hostname in trusted_dashboard # test lowercase hostname in trusted_dashboard
context = { environ = {'QUERY_STRING': 'origin=http://horizon.com'}
'query_string': { request = self.make_request(environ=environ)
'origin': "http://horizon.com", host = self.api._get_sso_origin_host(request)
},
}
host = self.api._get_sso_origin_host(context)
self.assertEqual("http://horizon.com", host) self.assertEqual("http://horizon.com", host)
# test uppercase hostname in trusted_dashboard # test uppercase hostname in trusted_dashboard
self.config_fixture.config(group='federation', self.config_fixture.config(group='federation',
trusted_dashboard=['http://Horizon.com']) trusted_dashboard=['http://Horizon.com'])
host = self.api._get_sso_origin_host(context) host = self.api._get_sso_origin_host(request)
self.assertEqual("http://horizon.com", host) self.assertEqual("http://horizon.com", host)
def test_federated_sso_auth_with_protocol_specific_remote_id(self): def test_federated_sso_auth_with_protocol_specific_remote_id(self):
@ -3673,10 +3670,10 @@ class WebSSOTests(FederatedTokenTests):
group=self.PROTOCOL, group=self.PROTOCOL,
remote_id_attribute=self.PROTOCOL_REMOTE_ID_ATTR) remote_id_attribute=self.PROTOCOL_REMOTE_ID_ATTR)
environment = {self.PROTOCOL_REMOTE_ID_ATTR: self.REMOTE_IDS[0]} environment = {self.PROTOCOL_REMOTE_ID_ATTR: self.REMOTE_IDS[0],
'QUERY_STRING': 'origin=%s' % self.ORIGIN}
environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environment) request = self.make_request(environ=environment)
query_string = {'origin': self.ORIGIN}
self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string)
resp = self.api.federated_sso_auth(request, self.PROTOCOL) resp = self.api.federated_sso_auth(request, self.PROTOCOL)
# `resp.body` will be `str` in Python 2 and `bytes` in Python 3 # `resp.body` will be `str` in Python 2 and `bytes` in Python 3
# which is why expected value: `self.TRUSTED_DASHBOARD` # which is why expected value: `self.TRUSTED_DASHBOARD`
@ -3684,61 +3681,61 @@ class WebSSOTests(FederatedTokenTests):
self.assertIn(self.TRUSTED_DASHBOARD.encode('utf-8'), resp.body) self.assertIn(self.TRUSTED_DASHBOARD.encode('utf-8'), resp.body)
def test_federated_sso_auth_bad_remote_id(self): def test_federated_sso_auth_bad_remote_id(self):
environment = {self.REMOTE_ID_ATTR: self.IDP} environment = {self.REMOTE_ID_ATTR: self.IDP,
'QUERY_STRING': 'origin=%s' % self.ORIGIN}
environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environment) request = self.make_request(environ=environment)
query_string = {'origin': self.ORIGIN}
self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string)
self.assertRaises(exception.IdentityProviderNotFound, self.assertRaises(exception.IdentityProviderNotFound,
self.api.federated_sso_auth, self.api.federated_sso_auth,
request, self.PROTOCOL) request, self.PROTOCOL)
def test_federated_sso_missing_query(self): def test_federated_sso_missing_query(self):
environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]}
environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environment) request = self.make_request(environ=environment)
self._inject_assertion(request, 'EMPLOYEE_ASSERTION')
self.assertRaises(exception.ValidationError, self.assertRaises(exception.ValidationError,
self.api.federated_sso_auth, self.api.federated_sso_auth,
request, self.PROTOCOL) request, self.PROTOCOL)
def test_federated_sso_missing_query_bad_remote_id(self): def test_federated_sso_missing_query_bad_remote_id(self):
environment = {self.REMOTE_ID_ATTR: self.IDP} environment = {self.REMOTE_ID_ATTR: self.IDP}
environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environment) request = self.make_request(environ=environment)
self._inject_assertion(request, 'EMPLOYEE_ASSERTION')
self.assertRaises(exception.ValidationError, self.assertRaises(exception.ValidationError,
self.api.federated_sso_auth, self.api.federated_sso_auth,
request, self.PROTOCOL) request, self.PROTOCOL)
def test_federated_sso_untrusted_dashboard(self): def test_federated_sso_untrusted_dashboard(self):
environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0],
'QUERY_STRING': 'origin=%s' % uuid.uuid4().hex}
environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environment) request = self.make_request(environ=environment)
query_string = {'origin': uuid.uuid4().hex}
self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string)
self.assertRaises(exception.Unauthorized, self.assertRaises(exception.Unauthorized,
self.api.federated_sso_auth, self.api.federated_sso_auth,
request, self.PROTOCOL) request, self.PROTOCOL)
def test_federated_sso_untrusted_dashboard_bad_remote_id(self): def test_federated_sso_untrusted_dashboard_bad_remote_id(self):
environment = {self.REMOTE_ID_ATTR: self.IDP} environment = {self.REMOTE_ID_ATTR: self.IDP,
'QUERY_STRING': 'origin=%s' % uuid.uuid4().hex}
environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
request = self.make_request(environ=environment) request = self.make_request(environ=environment)
query_string = {'origin': uuid.uuid4().hex}
self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string)
self.assertRaises(exception.Unauthorized, self.assertRaises(exception.Unauthorized,
self.api.federated_sso_auth, self.api.federated_sso_auth,
request, self.PROTOCOL) request, self.PROTOCOL)
def test_federated_sso_missing_remote_id(self): def test_federated_sso_missing_remote_id(self):
request = self.make_request() environment = copy.deepcopy(mapping_fixtures.EMPLOYEE_ASSERTION)
query_string = {'origin': self.ORIGIN} request = self.make_request(environ=environment,
self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) query_string='origin=%s' % self.ORIGIN)
self.assertRaises(exception.Unauthorized, self.assertRaises(exception.Unauthorized,
self.api.federated_sso_auth, self.api.federated_sso_auth,
request, self.PROTOCOL) request, self.PROTOCOL)
def test_identity_provider_specific_federated_authentication(self): def test_identity_provider_specific_federated_authentication(self):
environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]} environment = {self.REMOTE_ID_ATTR: self.REMOTE_IDS[0]}
request = self.make_request(environ=environment) environment.update(mapping_fixtures.EMPLOYEE_ASSERTION)
query_string = {'origin': self.ORIGIN} request = self.make_request(environ=environment,
self._inject_assertion(request, 'EMPLOYEE_ASSERTION', query_string) query_string='origin=%s' % self.ORIGIN)
resp = self.api.federated_idp_specific_sso_auth(request, resp = self.api.federated_idp_specific_sso_auth(request,
self.idp['id'], self.idp['id'],
self.PROTOCOL) self.PROTOCOL)

View File

@ -432,7 +432,7 @@ class Auth(controller.V2Controller):
the content body. the content body.
""" """
belongs_to = request.context_dict['query_string'].get('belongsTo') belongs_to = request.params.get('belongsTo')
return self.token_provider_api.validate_v2_token(token_id, belongs_to) return self.token_provider_api.validate_v2_token(token_id, belongs_to)
@controller.v2_deprecated @controller.v2_deprecated
@ -445,7 +445,7 @@ class Auth(controller.V2Controller):
Returns metadata about the token along any associated roles. Returns metadata about the token along any associated roles.
""" """
belongs_to = request.context_dict['query_string'].get('belongsTo') belongs_to = request.params.get('belongsTo')
# TODO(ayoung) validate against revocation API # TODO(ayoung) validate against revocation API
return self.token_provider_api.validate_v2_token(token_id, belongs_to) return self.token_provider_api.validate_v2_token(token_id, belongs_to)

View File

@ -216,20 +216,19 @@ class TrustV3(controller.V3Controller):
@controller.protected() @controller.protected()
def list_trusts(self, request): def list_trusts(self, request):
query = request.context_dict['query_string']
trusts = [] trusts = []
if not query: if not request.params:
self.assert_admin(request.context_dict) self.assert_admin(request.context_dict)
trusts += self.trust_api.list_trusts() trusts += self.trust_api.list_trusts()
if 'trustor_user_id' in query: if 'trustor_user_id' in request.params:
user_id = query['trustor_user_id'] user_id = request.params['trustor_user_id']
calling_user_id = self._get_user_id(request.context_dict) calling_user_id = self._get_user_id(request.context_dict)
if user_id != calling_user_id: if user_id != calling_user_id:
raise exception.Forbidden() raise exception.Forbidden()
trusts += (self.trust_api. trusts += (self.trust_api.
list_trusts_for_trustor(user_id)) list_trusts_for_trustor(user_id))
if 'trustee_user_id' in query: if 'trustee_user_id' in request.params:
user_id = query['trustee_user_id'] user_id = request.params['trustee_user_id']
calling_user_id = self._get_user_id(request.context_dict) calling_user_id = self._get_user_id(request.context_dict)
if user_id != calling_user_id: if user_id != calling_user_id:
raise exception.Forbidden() raise exception.Forbidden()