diff --git a/keystone/common/utils.py b/keystone/common/utils.py index b61baeaa35..ebe2bd41fc 100644 --- a/keystone/common/utils.py +++ b/keystone/common/utils.py @@ -577,3 +577,20 @@ def lower_case_hostname(url): # Note: _replace method for named tuples is public and defined in docs replaced = parsed._replace(netloc=parsed.netloc.lower()) return moves.urllib.parse.urlunparse(replaced) + + +def remove_standard_port(url): + # remove the default ports specified in RFC2616 and 2818 + o = moves.urllib.parse.urlparse(url) + separator = ':' + (host, separator, port) = o.netloc.partition(':') + if o.scheme.lower() == 'http' and port == '80': + # NOTE(gyee): _replace() is not a private method. It has an + # an underscore prefix to prevent conflict with field names. + # See https://docs.python.org/2/library/collections.html# + # collections.namedtuple + o = o._replace(netloc=host) + if o.scheme.lower() == 'https' and port == '443': + o = o._replace(netloc=host) + + return moves.urllib.parse.urlunparse(o) diff --git a/keystone/common/wsgi.py b/keystone/common/wsgi.py index 55752b7740..0c3ea8e310 100644 --- a/keystone/common/wsgi.py +++ b/keystone/common/wsgi.py @@ -20,6 +20,7 @@ import copy import itertools +import re import wsgiref.util from oslo_config import cfg @@ -376,13 +377,19 @@ class Application(BaseApplication): itertools.chain(CONF.items(), CONF.eventlet_server.items())) url = url % substitutions + elif 'environment' in context: + url = wsgiref.util.application_uri(context['environment']) + # remove version from the URL as it may be part of SCRIPT_NAME but + # it should not be part of base URL + url = re.sub(r'/v(3|(2\.0))/*$', '', url) + + # now remove the standard port + url = utils.remove_standard_port(url) else: - # NOTE(jamielennox): If url is not set via the config file we - # should set it relative to the url that the user used to get here - # so as not to mess with version discovery. This is not perfect. - # host_url omits the path prefix, but there isn't another good - # solution that will work for all urls. - url = context['host_url'] + # if we don't have enough information to come up with a base URL, + # then fall back to localhost. This should never happen in + # production environment. + url = 'http://localhost:%d' % CONF.eventlet_server.public_port return url.rstrip('/') @@ -812,18 +819,15 @@ def render_exception(error, context=None, request=None, user_locale=None): if isinstance(error, exception.AuthPluginException): body['error']['identity'] = error.authentication elif isinstance(error, exception.Unauthorized): - url = CONF.public_endpoint - if not url: - if request: - context = {'host_url': request.host_url} - if context: - url = Application.base_url(context, 'public') - else: - url = 'http://localhost:%d' % CONF.eventlet_server.public_port - else: - substitutions = dict( - itertools.chain(CONF.items(), CONF.eventlet_server.items())) - url = url % substitutions + # NOTE(gyee): we only care about the request environment in the + # context. Also, its OK to pass the environemt as it is read-only in + # Application.base_url() + local_context = {} + if request: + local_context = {'environment': request.environ} + elif context and 'environment' in context: + local_context = {'environment': context['environment']} + url = Application.base_url(local_context, 'public') headers.append(('WWW-Authenticate', 'Keystone uri="%s"' % url)) return render_response(status=(error.code, error.title), diff --git a/keystone/tests/unit/test_auth.py b/keystone/tests/unit/test_auth.py index 92a25b15d7..1b7320264a 100644 --- a/keystone/tests/unit/test_auth.py +++ b/keystone/tests/unit/test_auth.py @@ -14,6 +14,8 @@ import copy import datetime +import random +import string import uuid import mock @@ -41,7 +43,9 @@ CONF = cfg.CONF TIME_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ' DEFAULT_DOMAIN_ID = CONF.identity.default_domain_id -HOST_URL = 'http://keystone:5001' +HOST = ''.join(random.choice(string.ascii_lowercase) for x in range( + random.randint(5, 15))) +HOST_URL = 'http://%s' % (HOST) def _build_user_auth(token=None, user_id=None, username=None, @@ -871,7 +875,16 @@ class AuthWithTrust(AuthTest): token_id=token_id, token_data=self.token_provider_api.validate_token(token_id)) auth_context = authorization.token_to_auth_context(token_ref) - return {'environment': {authorization.AUTH_CONTEXT_ENV: auth_context}, + # NOTE(gyee): if public_endpoint and admin_endpoint are not set, which + # is the default, the base url will be constructed from the environment + # variables wsgi.url_scheme, SERVER_NAME, SERVER_PORT, and SCRIPT_NAME. + # We have to set them in the context so the base url can be constructed + # accordingly. + return {'environment': {authorization.AUTH_CONTEXT_ENV: auth_context, + 'wsgi.url_scheme': 'http', + 'SCRIPT_NAME': '/v3', + 'SERVER_PORT': '80', + 'SERVER_NAME': HOST}, 'token_id': token_id, 'host_url': HOST_URL} diff --git a/keystone/tests/unit/test_wsgi.py b/keystone/tests/unit/test_wsgi.py index e621cd54cd..564d7406cc 100644 --- a/keystone/tests/unit/test_wsgi.py +++ b/keystone/tests/unit/test_wsgi.py @@ -213,7 +213,9 @@ class ApplicationTest(BaseWSGITest): def test_render_exception_host(self): e = exception.Unauthorized(message=u'\u7f51\u7edc') - context = {'host_url': 'http://%s:5000' % uuid.uuid4().hex} + req = self._make_request(url='/') + context = {'host_url': 'http://%s:5000' % uuid.uuid4().hex, + 'environment': req.environ} resp = wsgi.render_exception(e, context=context) self.assertEqual(http_client.UNAUTHORIZED, resp.status_int) @@ -238,6 +240,77 @@ class ApplicationTest(BaseWSGITest): self.assertEqual({'name': u'nonexit\xe8nt'}, jsonutils.loads(resp.body)) + def test_base_url(self): + class FakeApp(wsgi.Application): + def index(self, context): + return self.base_url(context, 'public') + req = self._make_request(url='/') + # NOTE(gyee): according to wsgiref, if HTTP_HOST is present in the + # request environment, it will be used to construct the base url. + # SERVER_NAME and SERVER_PORT will be ignored. These are standard + # WSGI environment variables populated by the webserver. + req.environ.update({ + 'SCRIPT_NAME': '/identity', + 'SERVER_NAME': '1.2.3.4', + 'wsgi.url_scheme': 'http', + 'SERVER_PORT': '80', + 'HTTP_HOST': '1.2.3.4', + }) + resp = req.get_response(FakeApp()) + self.assertEqual(b"http://1.2.3.4/identity", resp.body) + + # if HTTP_HOST is absent, SERVER_NAME and SERVER_PORT will be used + req = self._make_request(url='/') + del req.environ['HTTP_HOST'] + req.environ.update({ + 'SCRIPT_NAME': '/identity', + 'SERVER_NAME': '1.1.1.1', + 'wsgi.url_scheme': 'http', + 'SERVER_PORT': '1234', + }) + resp = req.get_response(FakeApp()) + self.assertEqual(b"http://1.1.1.1:1234/identity", resp.body) + + # make sure keystone normalize the standard HTTP port 80 by stripping + # it + req = self._make_request(url='/') + req.environ.update({'HTTP_HOST': 'foo:80', + 'SCRIPT_NAME': '/identity'}) + resp = req.get_response(FakeApp()) + self.assertEqual(b"http://foo/identity", resp.body) + + # make sure keystone normalize the standard HTTPS port 443 by stripping + # it + req = self._make_request(url='/') + req.environ.update({'HTTP_HOST': 'foo:443', + 'SCRIPT_NAME': '/identity', + 'wsgi.url_scheme': 'https'}) + resp = req.get_response(FakeApp()) + self.assertEqual(b"https://foo/identity", resp.body) + + # make sure non-standard port is preserved + req = self._make_request(url='/') + req.environ.update({'HTTP_HOST': 'foo:1234', + 'SCRIPT_NAME': '/identity'}) + resp = req.get_response(FakeApp()) + self.assertEqual(b"http://foo:1234/identity", resp.body) + + # make sure version portion of the SCRIPT_NAME, '/v2.0', is stripped + # from base url + req = self._make_request(url='/') + req.environ.update({'HTTP_HOST': 'foo:80', + 'SCRIPT_NAME': '/bar/identity/v2.0'}) + resp = req.get_response(FakeApp()) + self.assertEqual(b"http://foo/bar/identity", resp.body) + + # make sure version portion of the SCRIPT_NAME, '/v3' is stripped from + # base url + req = self._make_request(url='/') + req.environ.update({'HTTP_HOST': 'foo:80', + 'SCRIPT_NAME': '/identity/v3'}) + resp = req.get_response(FakeApp()) + self.assertEqual(b"http://foo/identity", resp.body) + class ExtensionRouterTest(BaseWSGITest): def test_extensionrouter_local_config(self):