Merge "OAuth 2.0 Mutual-TLS Support"

This commit is contained in:
Zuul 2023-03-03 17:14:58 +00:00 committed by Gerrit Code Review
commit c08d97672d
17 changed files with 2428 additions and 111 deletions

View File

@ -16,16 +16,22 @@ import flask
from flask import make_response
import http.client
from oslo_log import log
from oslo_serialization import jsonutils
from keystone.api._shared import authentication
from keystone.api._shared import json_home_relations
from keystone.common import provider_api
from keystone.common import utils
from keystone.conf import CONF
from keystone import exception
from keystone.federation import utils as federation_utils
from keystone.i18n import _
from keystone.server import flask as ks_flask
LOG = log.getLogger(__name__)
PROVIDERS = provider_api.ProviderAPIs
_build_resource_relation = json_home_relations.os_oauth2_resource_rel_func
@ -69,44 +75,6 @@ class AccessTokenResource(ks_flask.ResourceBase):
POST /v3/OS-OAUTH2/token
"""
client_auth = flask.request.authorization
if not client_auth:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('OAuth2.0 client authorization is required.'))
LOG.info('Get OAuth2.0 Access Token API: '
'field \'authorization\' is not found in HTTP Headers.')
raise error
if client_auth.type != 'basic':
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('OAuth2.0 client authorization type %s is not supported.')
% client_auth.type)
LOG.info('Get OAuth2.0 Access Token API: '
f'{error.message_format}')
raise error
client_id = client_auth.username
client_secret = client_auth.password
if not client_id:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('OAuth2.0 client authorization is invalid.'))
LOG.info('Get OAuth2.0 Access Token API: '
'client_id is not found in authorization.')
raise error
if not client_secret:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('OAuth2.0 client authorization is invalid.'))
LOG.info('Get OAuth2.0 Access Token API: '
'client_secret is not found in authorization.')
raise error
grant_type = flask.request.form.get('grant_type')
if grant_type is None:
error = exception.OAuth2InvalidRequest(
@ -125,6 +93,45 @@ class AccessTokenResource(ks_flask.ResourceBase):
LOG.info('Get OAuth2.0 Access Token API: '
f'{error.message_format}')
raise error
auth_method = ''
client_id = flask.request.form.get('client_id')
client_secret = flask.request.form.get('client_secret')
client_cert = flask.request.environ.get("SSL_CLIENT_CERT")
client_auth = flask.request.authorization
if not client_cert and client_auth and client_auth.type == 'basic':
client_id = client_auth.username
client_secret = client_auth.password
if not client_id:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'failed to get a client_id from the request.')
raise error
if client_cert:
auth_method = 'tls_client_auth'
elif client_secret:
auth_method = 'client_secret_basic'
if auth_method in CONF.oauth2.oauth2_authn_methods:
if auth_method == 'tls_client_auth':
return self._tls_client_auth(client_id, client_cert)
if auth_method == 'client_secret_basic':
return self._client_secret_basic(client_id, client_secret)
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'failed to get client credentials from the request.')
raise error
def _client_secret_basic(self, client_id, client_secret):
"""Get an OAuth2.0 basic Access Token."""
auth_data = {
'identity': {
'methods': ['application_credential'],
@ -168,6 +175,202 @@ class AccessTokenResource(ks_flask.ResourceBase):
resp.status = '200 OK'
return resp
def _check_mapped_properties(self, cert_dn, user, user_domain):
mapping_id = CONF.oauth2.get('oauth2_cert_dn_mapping_id')
try:
mapping = PROVIDERS.federation_api.get_mapping(mapping_id)
except exception.MappingNotFound:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'mapping id %s is not found. ',
mapping_id)
raise error
rule_processor = federation_utils.RuleProcessor(
mapping.get('id'), mapping.get('rules'))
try:
mapped_properties = rule_processor.process(cert_dn)
except exception.Error as error:
LOG.exception(error)
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'mapping rule process failed. '
'mapping_id: %s, rules: %s, data: %s.',
mapping_id, mapping.get('rules'),
jsonutils.dumps(cert_dn))
raise error
except Exception as error:
LOG.exception(error)
error = exception.OAuth2OtherError(
int(http.client.INTERNAL_SERVER_ERROR),
http.client.responses[http.client.INTERNAL_SERVER_ERROR],
str(error))
LOG.info('Get OAuth2.0 Access Token API: '
'mapping rule process failed. '
'mapping_id: %s, rules: %s, data: %s.',
mapping_id, mapping.get('rules'),
jsonutils.dumps(cert_dn))
raise error
mapping_user = mapped_properties.get('user', {})
mapping_user_name = mapping_user.get('name')
mapping_user_id = mapping_user.get('id')
mapping_user_email = mapping_user.get('email')
mapping_domain = mapping_user.get('domain', {})
mapping_user_domain_id = mapping_domain.get('id')
mapping_user_domain_name = mapping_domain.get('name')
if mapping_user_name and mapping_user_name != user.get('name'):
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: %s check failed. '
'DN value: %s, DB value: %s.',
'user name', mapping_user_name, user.get('name'))
raise error
if mapping_user_id and mapping_user_id != user.get('id'):
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: %s check failed. '
'DN value: %s, DB value: %s.',
'user id', mapping_user_id, user.get('id'))
raise error
if mapping_user_email and mapping_user_email != user.get('email'):
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: %s check failed. '
'DN value: %s, DB value: %s.',
'user email', mapping_user_email, user.get('email'))
raise error
if (mapping_user_domain_id and
mapping_user_domain_id != user_domain.get('id')):
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: %s check failed. '
'DN value: %s, DB value: %s.',
'user domain id', mapping_user_domain_id,
user_domain.get('id'))
raise error
if (mapping_user_domain_name and
mapping_user_domain_name != user_domain.get('name')):
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: %s check failed. '
'DN value: %s, DB value: %s.',
'user domain name', mapping_user_domain_name,
user_domain.get('name'))
raise error
def _tls_client_auth(self, client_id, client_cert):
"""Get an OAuth2.0 certificate-bound Access Token."""
try:
cert_subject_dn = utils.get_certificate_subject_dn(client_cert)
except exception.ValidationError:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'failed to get the subject DN from the certificate.')
raise error
try:
cert_issuer_dn = utils.get_certificate_issuer_dn(client_cert)
except exception.ValidationError:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'failed to get the issuer DN from the certificate.')
raise error
client_cert_dn = {}
for key in cert_subject_dn:
client_cert_dn['SSL_CLIENT_SUBJECT_DN_%s' %
key.upper()] = cert_subject_dn.get(key)
for key in cert_issuer_dn:
client_cert_dn['SSL_CLIENT_ISSUER_DN_%s' %
key.upper()] = cert_issuer_dn.get(key)
try:
user = PROVIDERS.identity_api.get_user(client_id)
except exception.UserNotFound:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'the user does not exist. user id: %s.',
client_id)
raise error
project_id = user.get('default_project_id')
if not project_id:
error = exception.OAuth2InvalidClient(
int(http.client.UNAUTHORIZED),
http.client.responses[http.client.UNAUTHORIZED],
_('Client authentication failed.'))
LOG.info('Get OAuth2.0 Access Token API: '
'the user does not have default project. user id: %s.',
client_id)
raise error
user_domain = PROVIDERS.resource_api.get_domain(
user.get('domain_id'))
self._check_mapped_properties(client_cert_dn, user, user_domain)
thumbprint = utils.get_certificate_thumbprint(client_cert)
LOG.debug(f'The mTLS certificate thumbprint: {thumbprint}')
try:
token = PROVIDERS.token_provider_api.issue_token(
user_id=client_id,
method_names=['oauth2_credential'],
project_id=project_id,
thumbprint=thumbprint
)
except exception.Error as error:
if error.code == 401:
error = exception.OAuth2InvalidClient(
error.code, error.title,
str(error))
elif error.code == 400:
error = exception.OAuth2InvalidRequest(
error.code, error.title,
str(error))
else:
error = exception.OAuth2OtherError(
error.code, error.title,
'An unknown error occurred and failed to get an OAuth2.0 '
'access token.')
LOG.exception(error)
raise error
except Exception as error:
error = exception.OAuth2OtherError(
int(http.client.INTERNAL_SERVER_ERROR),
http.client.responses[http.client.INTERNAL_SERVER_ERROR],
str(error))
LOG.exception(error)
raise error
resp = make_response({
'access_token': token.id,
'token_type': 'Bearer',
'expires_in': CONF.token.expiration
})
resp.status = '200 OK'
return resp
class OSAuth2API(ks_flask.APIBase):
_name = 'OS-OAUTH2'

View File

@ -142,5 +142,9 @@ def render_token_response_from_model(token, include_catalog=True):
token_reference['token'][key]['access_rules'] = (
token.application_credential['access_rules']
)
if token.oauth2_thumbprint:
token_reference['token']['oauth2_credential'] = {
'x5t#S256': token.oauth2_thumbprint
}
return token_reference

View File

@ -15,7 +15,7 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import base64
import collections.abc
import contextlib
import grp
@ -25,6 +25,7 @@ import os
import pwd
import uuid
from cryptography import x509
from oslo_log import log
from oslo_serialization import jsonutils
from oslo_utils import reflection
@ -60,6 +61,14 @@ hash_user_password = password_hashing.hash_user_password
check_password = password_hashing.check_password
# NOTE(hiromu): This dict defines alternative DN string for X.509. When
# retriving DN from X.509, converting attributes types that are not listed
# in the RFC4514 to a corresponding alternative DN string.
ATTR_NAME_OVERRIDES = {
x509.NameOID.EMAIL_ADDRESS: "emailAddress",
}
def resource_uuid(value):
"""Convert input to valid UUID hex digits."""
try:
@ -458,6 +467,61 @@ def check_endpoint_url(url):
raise exception.URLValidationError(url=url)
def get_certificate_subject_dn(cert_pem):
"""Get subject DN from the PEM certificate content.
:param str cert_pem: the PEM certificate content
:rtype: JSON data for subject DN
:raises keystone.exception.ValidationError: if the PEM certificate content
is invalid
"""
dn_dict = {}
try:
cert = x509.load_pem_x509_certificate(cert_pem.encode('utf-8'))
for item in cert.subject:
name, value = item.rfc4514_string(
attr_name_overrides=ATTR_NAME_OVERRIDES).split('=')
dn_dict[name] = value
except Exception as error:
LOG.exception(error)
message = _('The certificate content is not PEM format.')
raise exception.ValidationError(message=message)
return dn_dict
def get_certificate_issuer_dn(cert_pem):
"""Get issuer DN from the PEM certificate content.
:param str cert_pem: the PEM certificate content
:rtype: JSON data for issuer DN
:raises keystone.exception.ValidationError: if the PEM certificate content
is invalid
"""
dn_dict = {}
try:
cert = x509.load_pem_x509_certificate(cert_pem.encode('utf-8'))
for item in cert.issuer:
name, value = item.rfc4514_string(
attr_name_overrides=ATTR_NAME_OVERRIDES).split('=')
dn_dict[name] = value
except Exception as error:
LOG.exception(error)
message = _('The certificate content is not PEM format.')
raise exception.ValidationError(message=message)
return dn_dict
def get_certificate_thumbprint(cert_pem):
"""Get certificate thumbprint from the PEM certificate content.
:param str cert_pem: the PEM certificate content
:rtype: certificate thumbprint
"""
thumb_sha256 = hashlib.sha256(cert_pem.encode('ascii')).digest()
thumbprint = base64.urlsafe_b64encode(thumb_sha256).decode('ascii')
return thumbprint
def create_directory(directory, keystone_user_id=None, keystone_group_id=None):
"""Attempt to create a directory if it doesn't exist.

View File

@ -40,6 +40,7 @@ from keystone.conf import jwt_tokens
from keystone.conf import ldap
from keystone.conf import memcache
from keystone.conf import oauth1
from keystone.conf import oauth2
from keystone.conf import policy
from keystone.conf import receipt
from keystone.conf import resource
@ -78,6 +79,7 @@ conf_modules = [
ldap,
memcache,
oauth1,
oauth2,
policy,
receipt,
resource,

52
keystone/conf/oauth2.py Normal file
View File

@ -0,0 +1,52 @@
# Copyright 2022 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_config import cfg
from keystone.conf import utils
oauth2_authn_methods = cfg.ListOpt(
'oauth2_authn_methods',
default=['tls_client_auth', 'client_secret_basic'],
help=utils.fmt("""
The OAuth2.0 authentication method supported by the system when user obtains
an access token through the OAuth2.0 token endpoint. This option can be set to
certificate or secret. If the option is not set, the default value is
certificate. When the option is set to secret, the OAuth2.0 token endpoint
uses client_secret_basic method for authentication, otherwise tls_client_auth
method is used for authentication.
"""))
oauth2_cert_dn_mapping_id = cfg.StrOpt(
'oauth2_cert_dn_mapping_id',
default='oauth2_mapping',
help=utils.fmt("""
Used to define the mapping rule id. When not set, the mapping rule id is
oauth2_mapping.
"""))
GROUP_NAME = __name__.split('.')[-1]
ALL_OPTS = [
oauth2_authn_methods,
oauth2_cert_dn_mapping_id
]
def register_opts(conf):
conf.register_opts(ALL_OPTS, group=GROUP_NAME)
def list_opts():
return {GROUP_NAME: ALL_OPTS}

View File

@ -79,6 +79,9 @@ class TokenModel(object):
self.application_credential_id = None
self.__application_credential = None
self.oauth2_credential_id = None
self.oauth2_thumbprint = None
def __repr__(self):
"""Return string representation of TokenModel."""
desc = ('<%(type)s (audit_id=%(audit_id)s, '
@ -440,6 +443,9 @@ class TokenModel(object):
return roles
def _get_oauth2_credential_roles(self):
return self._get_project_roles()
@property
def roles(self):
if self.system_scoped:

View File

@ -225,6 +225,114 @@ class UtilsTestCase(unit.BaseTestCase):
expected_string_ending = str(time.second) + 'Z'
self.assertTrue(string_time.endswith(expected_string_ending))
def test_get_certificate_subject_dn(self):
cert_pem = unit.create_pem_certificate(
unit.create_dn(
common_name='test',
organization_name='dev',
locality_name='suzhou',
state_or_province_name='jiangsu',
country_name='cn',
user_id='user_id',
domain_component='test.com',
email_address='user@test.com'
))
dn = common_utils.get_certificate_subject_dn(cert_pem)
self.assertEqual('test', dn.get('CN'))
self.assertEqual('dev', dn.get('O'))
self.assertEqual('suzhou', dn.get('L'))
self.assertEqual('jiangsu', dn.get('ST'))
self.assertEqual('cn', dn.get('C'))
self.assertEqual('user_id', dn.get('UID'))
self.assertEqual('test.com', dn.get('DC'))
self.assertEqual('user@test.com', dn.get('emailAddress'))
def test_get_certificate_issuer_dn(self):
root_cert, root_key = unit.create_certificate(
unit.create_dn(
country_name='jp',
state_or_province_name='kanagawa',
locality_name='kawasaki',
organization_name='fujitsu',
organizational_unit_name='test',
common_name='root'
))
cert_pem = unit.create_pem_certificate(
unit.create_dn(
common_name='test',
organization_name='dev',
locality_name='suzhou',
state_or_province_name='jiangsu',
country_name='cn',
user_id='user_id',
domain_component='test.com',
email_address='user@test.com'
), ca=root_cert, ca_key=root_key)
dn = common_utils.get_certificate_subject_dn(cert_pem)
self.assertEqual('test', dn.get('CN'))
self.assertEqual('dev', dn.get('O'))
self.assertEqual('suzhou', dn.get('L'))
self.assertEqual('jiangsu', dn.get('ST'))
self.assertEqual('cn', dn.get('C'))
self.assertEqual('user_id', dn.get('UID'))
self.assertEqual('test.com', dn.get('DC'))
self.assertEqual('user@test.com', dn.get('emailAddress'))
dn = common_utils.get_certificate_issuer_dn(cert_pem)
self.assertEqual('root', dn.get('CN'))
self.assertEqual('fujitsu', dn.get('O'))
self.assertEqual('kawasaki', dn.get('L'))
self.assertEqual('kanagawa', dn.get('ST'))
self.assertEqual('jp', dn.get('C'))
self.assertEqual('test', dn.get('OU'))
def test_get_certificate_subject_dn_not_pem_format(self):
self.assertRaises(
exception.ValidationError,
common_utils.get_certificate_subject_dn,
'MIIEkTCCAnkCFDIzsgpdRGF//5ukMuueXnRxQALhMA0GCSqGSIb3DQEBCwUAMIGC')
def test_get_certificate_issuer_dn_not_pem_format(self):
self.assertRaises(
exception.ValidationError,
common_utils.get_certificate_issuer_dn,
'MIIEkTCCAnkCFDIzsgpdRGF//5ukMuueXnRxQALhMA0GCSqGSIb3DQEBCwUAMIGC')
def test_get_certificate_thumbprint(self):
cert_pem = '''-----BEGIN CERTIFICATE-----
MIIEkTCCAnkCFDIzsgpdRGF//5ukMuueXnRxQALhMA0GCSqGSIb3DQEBCwUAMIGC
MQswCQYDVQQGEwJjbjEQMA4GA1UECAwHamlhbmdzdTEPMA0GA1UEBwwGc3V6aG91
MQ0wCwYDVQQKDARqZnR0MQwwCgYDVQQLDANkZXYxEzARBgNVBAMMCnJvb3QubG9j
YWwxHjAcBgkqhkiG9w0BCQEWD3Rlc3RAcm9vdC5sb2NhbDAeFw0yMjA2MTYwNzM3
NTZaFw0yMjEyMTMwNzM3NTZaMIGGMQswCQYDVQQGEwJjbjEQMA4GA1UECAwHamlh
bmdzdTEPMA0GA1UEBwwGc3V6aG91MQ0wCwYDVQQKDARqZnR0MQwwCgYDVQQLDANk
ZXYxFTATBgNVBAMMDGNsaWVudC5sb2NhbDEgMB4GCSqGSIb3DQEJARYRdGVzdEBj
bGllbnQubG9jYWwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCah1Uz
2OVbk8zLslxxGV+AR6FTy9b/VoinmB6A0jJA1Zz2D6rsjN2S5xQ5wHIO2WSVX9Ry
SonOmeZZqRA9faNJcNNcrBhJICScAhMGHCuli3EUMry/6xK0OYHGgI2X6mcTaIjv
tFKHO1BCb5YGdNBa+ff+ncTeVX/PeN3nKjA4xvQb9JZxJTgY0JVhledbaoepFSdW
EFW0nbUF+8lj1gCo5E4cAX1eTcUKs43FnWGCJcJT6FB1vP9x8e4h9p0RWbb9GMrU
DXKbzF5e28qIiCkYHv2/A/G/J+aeg2K4Cbqy+8908I5BdWZEsJBhWJ0+CEtC3n91
fU6dnAyipO496aa/AgMBAAEwDQYJKoZIhvcNAQELBQADggIBABoOOmLrWNlQzodS
n2wfkiF0Lz+pj3FKFPz3sYUYWkAiKXU/6RRu1Md7INRo0MFau4iAN8Raq4JFdbnU
HRN9G/UU58ETqi/8cYfOA2+MHHRif1Al9YSvTgHQa6ljZPttGeigOqmGlovPd+7R
vLXlKtcr5XBVk9pWPmVpwtAN3bMVlphgEqBO26Ff9J3G5PaNQ6UdpwXC19mRqk6r
BUsFBRwy7EeeGNy8DvoHTJfMc2JUbLjesSMOmIkaOGbhe327iRd/GJe4dO91+prE
HNWVR/bVoGiUZvSLPqrwU173XbdNd6yMKC+fULICI34eaWDe1zHrg9XdRxtessUx
OyJw5bgH09lOs8DSYXjFyx5lDxtERKHaLRgpSNd5foQO/mHiegC2qmdtxqKyOwub
V/h6vziDsFZfciwmo6iw3ZpdBvjbYqw32joURQ1IVh1naY6ZzMwq/PsyYVhMYUNB
XYPKvm68YfKuYmpwF7Z5Wll4EWm5DTq1dbmjdo+OQsMyiwWepWE0WV7Ng+AEbTqP
/akzUXt/AEbbBpZskB6v5q/YOcglWuAQVXs2viguyDvOQVbEB7JKDi4xzlZg3kQP
apjt17fip7wQi2jJkwdyAqvrdi/xLhK5+6BSo04lNc8sGZ9wToIoNkgv0cG+BrVU
4cJHNiTQl8bxfSgwemgSYnnyXM4k
-----END CERTIFICATE-----'''
thumbprint = common_utils.get_certificate_thumbprint(cert_pem)
self.assertEqual('dMmoJKE9MIJK9VcyahYCb417JDhDfdtTiq_krco8-tk=',
thumbprint)
class ServiceHelperTests(unit.BaseTestCase):

View File

@ -28,6 +28,10 @@ import socket
import sys
import uuid
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography import x509
import fixtures
import flask
from flask import testing as flask_testing
@ -433,6 +437,77 @@ def new_totp_credential(user_id, project_id=None, blob=None):
return credential
def create_dn(
common_name=None,
locality_name=None,
state_or_province_name=None,
organization_name=None,
organizational_unit_name=None,
country_name=None,
street_address=None,
domain_component=None,
user_id=None,
email_address=None,
):
oid = x509.NameOID
attr = x509.NameAttribute
dn = []
if common_name:
dn.append(attr(oid.COMMON_NAME, common_name))
if locality_name:
dn.append(attr(oid.LOCALITY_NAME, locality_name))
if state_or_province_name:
dn.append(attr(oid.STATE_OR_PROVINCE_NAME, state_or_province_name))
if organization_name:
dn.append(attr(oid.ORGANIZATION_NAME, organization_name))
if organizational_unit_name:
dn.append(attr(oid.ORGANIZATIONAL_UNIT_NAME, organizational_unit_name))
if country_name:
dn.append(attr(oid.COUNTRY_NAME, country_name))
if street_address:
dn.append(attr(oid.STREET_ADDRESS, street_address))
if domain_component:
dn.append(attr(oid.DOMAIN_COMPONENT, domain_component))
if user_id:
dn.append(attr(oid.USER_ID, user_id))
if email_address:
dn.append(attr(oid.EMAIL_ADDRESS, email_address))
return x509.Name(dn)
def update_dn(dn1, dn2):
dn1_attrs = {attr.oid: attr for attr in dn1}
dn2_attrs = {attr.oid: attr for attr in dn2}
dn1_attrs.update(dn2_attrs)
return x509.Name([attr for attr in dn1_attrs.values()])
def create_certificate(subject_dn, ca=None, ca_key=None):
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
issuer = ca.subject if ca else subject_dn
if not ca_key:
ca_key = private_key
today = datetime.datetime.today()
cert = x509.CertificateBuilder(
issuer_name=issuer,
subject_name=subject_dn,
public_key=private_key.public_key(),
serial_number=x509.random_serial_number(),
not_valid_before=today,
not_valid_after=today + datetime.timedelta(365, 0, 0),
).sign(ca_key, hashes.SHA256())
return cert, private_key
def create_pem_certificate(subject_dn, ca=None, ca_key=None):
cert, _ = create_certificate(subject_dn, ca=ca, ca_key=ca_key)
return cert.public_bytes(Encoding.PEM).decode('ascii')
def new_application_credential_ref(roles=None,
name=None,
expires=None,

View File

@ -19,8 +19,10 @@ import itertools
import operator
import re
from unittest import mock
from urllib import parse
import uuid
from cryptography.hazmat.primitives.serialization import Encoding
import freezegun
import http.client
from oslo_serialization import jsonutils as json
@ -2645,6 +2647,187 @@ class TokenAPITests(object):
r = self._validate_token(token, allow_expired=True)
self.assertValidProjectScopedTokenResponse(r)
def _create_project_user(self):
new_domain_ref = unit.new_domain_ref()
PROVIDERS.resource_api.create_domain(
new_domain_ref['id'], new_domain_ref
)
new_project_ref = unit.new_project_ref(domain_id=self.domain_id)
PROVIDERS.resource_api.create_project(
new_project_ref['id'], new_project_ref
)
new_user = unit.create_user(PROVIDERS.identity_api,
domain_id=new_domain_ref['id'],
project_id=new_project_ref['id'])
PROVIDERS.assignment_api.create_grant(
self.role['id'],
user_id=new_user['id'],
project_id=new_project_ref['id'])
return new_user, new_domain_ref, new_project_ref
def _create_certificates(self,
root_dn=None,
server_dn=None,
client_dn=None):
root_subj = unit.create_dn(
country_name='jp',
state_or_province_name='kanagawa',
locality_name='kawasaki',
organization_name='fujitsu',
organizational_unit_name='test',
common_name='root'
)
if root_dn:
root_subj = unit.update_dn(root_subj, root_dn)
root_cert, root_key = unit.create_certificate(root_subj)
keystone_subj = unit.create_dn(
country_name='jp',
state_or_province_name='kanagawa',
locality_name='kawasaki',
organization_name='fujitsu',
organizational_unit_name='test',
common_name='keystone.local'
)
if server_dn:
keystone_subj = unit.update_dn(keystone_subj, server_dn)
ks_cert, ks_key = unit.create_certificate(
keystone_subj, ca=root_cert, ca_key=root_key)
client_subj = unit.create_dn(
country_name='jp',
state_or_province_name='kanagawa',
locality_name='kawasaki',
organization_name='fujitsu',
organizational_unit_name='test',
common_name='client'
)
if client_dn:
client_subj = unit.update_dn(client_subj, client_dn)
client_cert, client_key = unit.create_certificate(
client_subj, ca=root_cert, ca_key=root_key)
return root_cert, root_key, ks_cert, ks_key, client_cert, client_key
def _get_cert_content(self, cert):
return cert.public_bytes(Encoding.PEM).decode('ascii')
def _get_oauth2_access_token(self, client_id, client_cert_content,
expected_status=http.client.OK):
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
}
data = {
'grant_type': 'client_credentials',
'client_id': client_id
}
extra_environ = {
'SSL_CLIENT_CERT': client_cert_content
}
data = parse.urlencode(data).encode()
resp = self.post(
'/OS-OAUTH2/token',
headers=headers,
noauth=True,
convert=False,
body=data,
environ=extra_environ,
expected_status=expected_status)
return resp
def _create_mapping(self):
mapping = {
'id': 'oauth2_mapping',
'rules': [
{
'local': [
{
'user': {
'name': '{0}',
'id': '{1}',
'email': '{2}',
'domain': {
'name': '{3}',
'id': '{4}'
}
}
}
],
'remote': [
{
'type': 'SSL_CLIENT_SUBJECT_DN_CN'
},
{
'type': 'SSL_CLIENT_SUBJECT_DN_UID'
},
{
'type': 'SSL_CLIENT_SUBJECT_DN_EMAILADDRESS'
},
{
'type': 'SSL_CLIENT_SUBJECT_DN_O'
},
{
'type': 'SSL_CLIENT_SUBJECT_DN_DC'
},
{
'type': 'SSL_CLIENT_ISSUER_DN_CN',
'any_one_of': [
'root'
]
}
]
}
]
}
PROVIDERS.federation_api.create_mapping(mapping['id'], mapping)
def test_verify_oauth2_token_project_scope_ok(self):
cache_on_issue = CONF.token.cache_on_issue
caching = CONF.token.caching
self._create_mapping()
user, user_domain, _ = self._create_project_user()
*_, client_cert, _ = self._create_certificates(
root_dn=unit.create_dn(
common_name='root'
),
client_dn=unit.create_dn(
common_name=user['name'],
user_id=user['id'],
email_address=user['email'],
organization_name=user_domain['name'],
domain_component=user_domain['id']
)
)
cert_content = self._get_cert_content(client_cert)
CONF.token.cache_on_issue = False
CONF.token.caching = False
resp = self._get_oauth2_access_token(user['id'], cert_content)
json_resp = json.loads(resp.body)
self.assertIn('access_token', json_resp)
self.assertEqual('Bearer', json_resp['token_type'])
self.assertEqual(3600, json_resp['expires_in'])
verify_resp = self.get(
'/auth/tokens',
headers={
'X-Subject-Token': json_resp['access_token'],
'X-Auth-Token': json_resp['access_token']
},
expected_status=http.client.OK)
self.assertIn('token', verify_resp.result)
self.assertIn('oauth2_credential', verify_resp.result['token'])
self.assertIn('roles', verify_resp.result['token'])
self.assertIn('project', verify_resp.result['token'])
self.assertIn('catalog', verify_resp.result['token'])
check_oauth2 = verify_resp.result['token']['oauth2_credential']
self.assertEqual(utils.get_certificate_thumbprint(cert_content),
check_oauth2['x5t#S256'])
CONF.token.cache_on_issue = cache_on_issue
CONF.token.caching = caching
class TokenDataTests(object):
"""Test the data in specific token types."""

File diff suppressed because it is too large Load Diff

View File

@ -317,7 +317,7 @@ class TestTokenFormatter(unit.TestCase):
(user_id, methods, audit_ids, system, domain_id, project_id, trust_id,
federated_group_ids, identity_provider_id, protocol_id,
access_token_id, app_cred_id, issued_at,
access_token_id, app_cred_id, thumbprint, issued_at,
expires_at) = token_formatter.validate_token(token)
self.assertEqual(exp_user_id, user_id)
@ -352,7 +352,7 @@ class TestTokenFormatter(unit.TestCase):
(user_id, methods, audit_ids, system, domain_id, project_id, trust_id,
federated_group_ids, identity_provider_id, protocol_id,
access_token_id, app_cred_id, issued_at,
access_token_id, app_cred_id, thumbprint, issued_at,
expires_at) = token_formatter.validate_token(token)
self.assertEqual(exp_user_id, user_id)
@ -473,7 +473,7 @@ class TestPayloads(unit.TestCase):
exp_trust_id=None, exp_federated_group_ids=None,
exp_identity_provider_id=None, exp_protocol_id=None,
exp_access_token_id=None, exp_app_cred_id=None,
encode_ids=False):
encode_ids=False, exp_thumbprint=None):
def _encode_id(value):
if value is not None and str(value) and encode_ids:
return value.encode('utf-8')
@ -496,12 +496,14 @@ class TestPayloads(unit.TestCase):
_encode_id(exp_identity_provider_id),
exp_protocol_id,
_encode_id(exp_access_token_id),
_encode_id(exp_app_cred_id))
_encode_id(exp_app_cred_id),
exp_thumbprint)
(user_id, methods, system, project_id,
domain_id, expires_at, audit_ids,
trust_id, federated_group_ids, identity_provider_id, protocol_id,
access_token_id, app_cred_id) = payload_class.disassemble(payload)
access_token_id, app_cred_id,
thumbprint) = payload_class.disassemble(payload)
self.assertEqual(exp_user_id, user_id)
self.assertEqual(exp_methods, methods)

View File

@ -154,8 +154,8 @@ class Manager(manager.Manager):
def _validate_token(self, token_id):
(user_id, methods, audit_ids, system, domain_id,
project_id, trust_id, federated_group_ids, identity_provider_id,
protocol_id, access_token_id, app_cred_id, issued_at,
expires_at) = self.driver.validate_token(token_id)
protocol_id, access_token_id, app_cred_id, thumbprint,
issued_at, expires_at) = self.driver.validate_token(token_id)
token = token_model.TokenModel()
token.user_id = user_id
@ -169,6 +169,7 @@ class Manager(manager.Manager):
token.trust_id = trust_id
token.access_token_id = access_token_id
token.application_credential_id = app_cred_id
token.oauth2_thumbprint = thumbprint
token.expires_at = expires_at
if federated_group_ids is not None:
token.is_federated = True
@ -221,7 +222,7 @@ class Manager(manager.Manager):
def issue_token(self, user_id, method_names, expires_at=None,
system=None, project_id=None, domain_id=None,
auth_context=None, trust_id=None, app_cred_id=None,
parent_audit_id=None):
thumbprint=None, parent_audit_id=None):
# NOTE(lbragstad): Grab a blank token object and use composition to
# build the token according to the authentication and authorization
@ -235,6 +236,7 @@ class Manager(manager.Manager):
token.trust_id = trust_id
token.application_credential_id = app_cred_id
token.audit_id = random_urlsafe_str()
token.oauth2_thumbprint = thumbprint
token.parent_audit_id = parent_audit_id
if auth_context:

View File

@ -44,6 +44,7 @@ class Provider(object, metaclass=abc.ABCMeta):
``protocol_id`` unique ID of the protocol used to obtain the token
``access_token_id`` the unique ID of the access_token for OAuth1 tokens
``app_cred_id`` the unique ID of the application credential
``param thumbprint`` thumbprint of the certificate for OAuth2.0 mTLS
``issued_at`` a datetime object of when the token was minted
``expires_at`` a datetime object of when the token expires

View File

@ -58,6 +58,8 @@ class Provider(base.Provider):
return tf.FederatedUnscopedPayload
elif token.application_credential_id:
return tf.ApplicationCredentialScopedPayload
elif token.oauth2_thumbprint:
return tf.Oauth2CredentialsScopedPayload
elif token.project_scoped:
return tf.ProjectScopedPayload
elif token.domain_scoped:
@ -83,7 +85,8 @@ class Provider(base.Provider):
identity_provider_id=token.identity_provider_id,
protocol_id=token.protocol_id,
access_token_id=token.access_token_id,
app_cred_id=token.application_credential_id
app_cred_id=token.application_credential_id,
thumbprint=token.oauth2_thumbprint,
)
creation_datetime_obj = self.token_formatter.creation_time(token_id)
issued_at = ks_utils.isotime(

View File

@ -70,7 +70,8 @@ class Provider(base.Provider):
identity_provider_id=token.identity_provider_id,
protocol_id=token.protocol_id,
access_token_id=token.access_token_id,
app_cred_id=token.application_credential_id
app_cred_id=token.application_credential_id,
thumbprint=token.oauth2_thumbprint,
)
def validate_token(self, token_id):
@ -106,7 +107,8 @@ class JWSFormatter(object):
system=None, domain_id=None, project_id=None,
trust_id=None, federated_group_ids=None,
identity_provider_id=None, protocol_id=None,
access_token_id=None, app_cred_id=None):
access_token_id=None, app_cred_id=None,
thumbprint=None):
issued_at = utils.isotime(subsecond=True)
issued_at_int = self._convert_time_string_to_int(issued_at)
@ -128,7 +130,8 @@ class JWSFormatter(object):
'openstack_idp_id': identity_provider_id,
'openstack_protocol_id': protocol_id,
'openstack_access_token_id': access_token_id,
'openstack_app_cred_id': app_cred_id
'openstack_app_cred_id': app_cred_id,
'openstack_thumbprint': thumbprint,
}
# NOTE(lbragstad): Calling .items() on a dictionary in python 2 returns
@ -164,6 +167,7 @@ class JWSFormatter(object):
protocol_id = payload.get('openstack_protocol_id', None)
access_token_id = payload.get('openstack_access_token_id', None)
app_cred_id = payload.get('openstack_app_cred_id', None)
thumbprint = payload.get('openstack_thumbprint', None)
issued_at = self._convert_time_int_to_string(issued_at_int)
expires_at = self._convert_time_int_to_string(expires_at_int)
@ -171,7 +175,7 @@ class JWSFormatter(object):
return (
user_id, methods, audit_ids, system, domain_id, project_id,
trust_id, federated_group_ids, identity_provider_id, protocol_id,
access_token_id, app_cred_id, issued_at, expires_at
access_token_id, app_cred_id, thumbprint, issued_at, expires_at,
)
def _decode_token_from_id(self, token_id):

View File

@ -137,14 +137,14 @@ class TokenFormatter(object):
methods=None, system=None, domain_id=None,
project_id=None, trust_id=None, federated_group_ids=None,
identity_provider_id=None, protocol_id=None,
access_token_id=None, app_cred_id=None):
access_token_id=None, app_cred_id=None,
thumbprint=None):
"""Given a set of payload attributes, generate a Fernet token."""
version = payload_class.version
payload = payload_class.assemble(
user_id, methods, system, project_id, domain_id, expires_at,
audit_ids, trust_id, federated_group_ids, identity_provider_id,
protocol_id, access_token_id, app_cred_id
)
protocol_id, access_token_id, app_cred_id, thumbprint)
versioned_payload = (version,) + payload
serialized_payload = msgpack.packb(versioned_payload)
@ -187,7 +187,8 @@ class TokenFormatter(object):
(user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id) = payload_class.disassemble(payload)
app_cred_id, thumbprint) = (
payload_class.disassemble(payload))
break
else:
# If the token_format is not recognized, raise ValidationError.
@ -211,8 +212,8 @@ class TokenFormatter(object):
return (user_id, methods, audit_ids, system, domain_id, project_id,
trust_id, federated_group_ids, identity_provider_id,
protocol_id, access_token_id, app_cred_id, issued_at,
expires_at)
protocol_id, access_token_id, app_cred_id, thumbprint,
issued_at, expires_at)
class BasePayload(object):
@ -223,7 +224,7 @@ class BasePayload(object):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
"""Assemble the payload of a token.
:param user_id: identifier of the user in the token request
@ -239,6 +240,7 @@ class BasePayload(object):
:param protocol_id: federated protocol used for authentication
:param access_token_id: ID of the secret in OAuth1 authentication
:param app_cred_id: ID of the application credential in effect
:param thumbprint: thumbprint of the certificate in OAuth2 mTLS
:returns: the payload of a token
"""
@ -377,7 +379,7 @@ class UnscopedPayload(BasePayload):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
b_user_id = cls.attempt_convert_uuid_hex_to_bytes(user_id)
methods = auth_plugins.convert_method_list_to_integer(methods)
expires_at_int = cls._convert_time_string_to_float(expires_at)
@ -401,10 +403,11 @@ class UnscopedPayload(BasePayload):
protocol_id = None
access_token_id = None
app_cred_id = None
thumbprint = None
return (user_id, methods, system, project_id, domain_id,
expires_at_str, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id)
app_cred_id, thumbprint)
class DomainScopedPayload(BasePayload):
@ -414,7 +417,7 @@ class DomainScopedPayload(BasePayload):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
b_user_id = cls.attempt_convert_uuid_hex_to_bytes(user_id)
methods = auth_plugins.convert_method_list_to_integer(methods)
try:
@ -455,10 +458,11 @@ class DomainScopedPayload(BasePayload):
protocol_id = None
access_token_id = None
app_cred_id = None
thumbprint = None
return (user_id, methods, system, project_id, domain_id,
expires_at_str, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id)
app_cred_id, thumbprint)
class ProjectScopedPayload(BasePayload):
@ -468,7 +472,7 @@ class ProjectScopedPayload(BasePayload):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
b_user_id = cls.attempt_convert_uuid_hex_to_bytes(user_id)
methods = auth_plugins.convert_method_list_to_integer(methods)
b_project_id = cls.attempt_convert_uuid_hex_to_bytes(project_id)
@ -494,10 +498,11 @@ class ProjectScopedPayload(BasePayload):
protocol_id = None
access_token_id = None
app_cred_id = None
thumbprint = None
return (user_id, methods, system, project_id, domain_id,
expires_at_str, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id)
app_cred_id, thumbprint)
class TrustScopedPayload(BasePayload):
@ -507,7 +512,7 @@ class TrustScopedPayload(BasePayload):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
b_user_id = cls.attempt_convert_uuid_hex_to_bytes(user_id)
methods = auth_plugins.convert_method_list_to_integer(methods)
b_project_id = cls.attempt_convert_uuid_hex_to_bytes(project_id)
@ -536,10 +541,11 @@ class TrustScopedPayload(BasePayload):
protocol_id = None
access_token_id = None
app_cred_id = None
thumbprint = None
return (user_id, methods, system, project_id, domain_id,
expires_at_str, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id)
app_cred_id, thumbprint)
class FederatedUnscopedPayload(BasePayload):
@ -559,7 +565,7 @@ class FederatedUnscopedPayload(BasePayload):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
b_user_id = cls.attempt_convert_uuid_hex_to_bytes(user_id)
methods = auth_plugins.convert_method_list_to_integer(methods)
b_group_ids = list(map(cls.pack_group_id, federated_group_ids))
@ -590,9 +596,10 @@ class FederatedUnscopedPayload(BasePayload):
trust_id = None
access_token_id = None
app_cred_id = None
thumbprint = None
return (user_id, methods, system, project_id, domain_id,
expires_at_str, audit_ids, trust_id, group_ids, idp_id,
protocol_id, access_token_id, app_cred_id)
protocol_id, access_token_id, app_cred_id, thumbprint)
class FederatedScopedPayload(FederatedUnscopedPayload):
@ -602,7 +609,7 @@ class FederatedScopedPayload(FederatedUnscopedPayload):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
b_user_id = cls.attempt_convert_uuid_hex_to_bytes(user_id)
methods = auth_plugins.convert_method_list_to_integer(methods)
b_scope_id = cls.attempt_convert_uuid_hex_to_bytes(
@ -641,9 +648,10 @@ class FederatedScopedPayload(FederatedUnscopedPayload):
trust_id = None
access_token_id = None
app_cred_id = None
thumbprint = None
return (user_id, methods, system, project_id, domain_id,
expires_at_str, audit_ids, trust_id, group_ids, idp_id,
protocol_id, access_token_id, app_cred_id)
protocol_id, access_token_id, app_cred_id, thumbprint)
class FederatedProjectScopedPayload(FederatedScopedPayload):
@ -661,7 +669,7 @@ class OauthScopedPayload(BasePayload):
def assemble(cls, user_id, methods, system, project_id, domain_id,
expires_at, audit_ids, trust_id, federated_group_ids,
identity_provider_id, protocol_id, access_token_id,
app_cred_id):
app_cred_id, thumbprint):
b_user_id = cls.attempt_convert_uuid_hex_to_bytes(user_id)
methods = auth_plugins.convert_method_list_to_integer(methods)
b_project_id = cls.attempt_convert_uuid_hex_to_bytes(project_id)
@ -692,11 +700,12 @@ class OauthScopedPayload(BasePayload):
identity_provider_id = None
protocol_id = None
app_cred_id = None
thumbprint = None
return (user_id, methods, system, project_id, domain_id,
expires_at_str, audit_ids</