updated Keycloak_auth to use public key for validation

Previously there were 2 ways to validate the access_token
1. define auth_url, which will send the access_token to
keycloak for validation
2. define absolute user_info_endpoint_url to validate the
 custom url

In this patch we have removed the 1st option to validate
 everytime with keycloak instead we are taking keyclaok certs
 and constructing public key with it and then validating using
 access_token using public key and iss(optional).
We are keeping the public key in cache to not to request
 keycloak repeatedly.

Change-Id: Ie0551c2f9f8a37debd50e7aebcf35f7143db44f9
This commit is contained in:
kushalagrawal 2019-08-06 16:06:29 +05:30
parent 5815d94781
commit 47fd843000
2 changed files with 127 additions and 56 deletions

View File

@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from cachetools import cached
from cachetools import LRUCache
import json
import jwt
import memcache
from oslo_config import cfg
@ -25,6 +28,7 @@ import webob.dec
from glare.common import exception
from glare.common import utils
from glare.i18n import _
from jwt.algorithms import RSAAlgorithm
LOG = logging.getLogger(__name__)
@ -70,6 +74,15 @@ keycloak_oidc_opts = [
'tokens, the middleware caches previously-seen tokens '
'for a configurable duration (in seconds).'
),
cfg.StrOpt(
'public_cert_url',
default="/realms/%s/protocol/openid-connect/certs",
help="URL to get the public key for perticualar realm"
),
cfg.StrOpt(
'keycloak_iss',
help="keycloak issuer(iss) url Ex: https://ip_add:port/auth/realms/%s"
)
]
CONF = cfg.CONF
@ -81,67 +94,34 @@ class KeycloakAuthMiddleware(base_middleware.Middleware):
super(KeycloakAuthMiddleware, self).__init__(application=app)
mcserv_url = CONF.keycloak_oidc.memcached_server
self.mcclient = memcache.Client(mcserv_url) if mcserv_url else None
self.certfile = CONF.keycloak_oidc.certfile
self.keyfile = CONF.keycloak_oidc.keyfile
self.cafile = CONF.keycloak_oidc.cafile or utils.get_system_ca_file()
self.insecure = CONF.keycloak_oidc.insecure
def authenticate(self, access_token, realm_name):
def authenticate(self, access_token, realm_name, audience):
info = None
if self.mcclient:
info = self.mcclient.get(access_token)
user_info_endpoint_url = CONF.keycloak_oidc.user_info_endpoint_url
if info is None and user_info_endpoint_url:
if info is None:
if user_info_endpoint_url.startswith(('http://', 'https://')):
url = user_info_endpoint_url
info = self.send_request_to_auth_server(
url=user_info_endpoint_url, access_token=access_token)
else:
url_template = CONF.keycloak_oidc.auth_url + \
CONF.keycloak_oidc.user_info_endpoint_url
url = url_template % realm_name
verify = None
if urllib.parse.urlparse(url).scheme == "https":
verify = False if self.insecure else self.cafile
cert = (self.certfile, self.keyfile) \
if self.certfile and self.keyfile else None
try:
resp = requests.get(
url,
headers={"Authorization": "Bearer %s" % access_token},
verify=verify,
cert=cert
)
except requests.ConnectionError:
msg = _("Can't connect to keycloak server with address '%s'."
) % CONF.keycloak_oidc.auth_url
LOG.error(msg)
raise exception.GlareException(message=msg)
if resp.status_code == 400:
raise exception.BadRequest(message=resp.text)
if resp.status_code == 401:
LOG.warning("HTTP response from OIDC provider:"
" [%s] with WWW-Authenticate: [%s]",
pprint.pformat(resp.text),
resp.headers.get("WWW-Authenticate"))
raise exception.Unauthorized(message=resp.text)
if resp.status_code == 403:
raise exception.Forbidden(message=resp.text)
elif resp.status_code > 400:
raise exception.GlareException(message=resp.text)
if self.mcclient:
self.mcclient.set(access_token, resp.json(),
time=CONF.keycloak_oidc.token_cache_time)
info = resp.json()
LOG.debug("HTTP response from OIDC provider: %s",
pprint.pformat(info))
public_key = self.get_public_key(realm_name)
keycloak_iss = None
try:
if CONF.keycloak_oidc.keycloak_iss:
keycloak_iss = \
CONF.keycloak_oidc.keycloak_iss % realm_name
jwt.decode(access_token, public_key, audience=audience,
issuer=keycloak_iss, algorithms=['RS256'],
verify=True)
except Exception as e:
LOG.error("Exception in access_token validation %s", e)
raise exception.Unauthorized()
return info
@webob.dec.wsgify
@ -163,14 +143,62 @@ class KeycloakAuthMiddleware(base_middleware.Middleware):
# Get user realm from parsed token
# Format is "iss": "http://<host>:<port>/auth/realms/<realm_name>",
__, __, realm_name = decoded['iss'].strip().rpartition('/realms/')
audience = decoded['aud']
# Get roles from from parsed token
roles = ','.join(decoded['realm_access']['roles']) \
if 'realm_access' in decoded else ''
self.authenticate(access_token, realm_name)
self.authenticate(access_token, realm_name, audience)
request.headers["X-Identity-Status"] = "Confirmed"
request.headers["X-Project-Id"] = realm_name
request.headers["X-Roles"] = roles
return request.get_response(self.application)
@cached(LRUCache(maxsize=32))
def get_public_key(self, realm_name):
keycloak_key_url = CONF.keycloak_oidc.auth_url + \
CONF.keycloak_oidc.public_cert_url % realm_name
response_json = self.send_request_to_auth_server(keycloak_key_url)
public_key = RSAAlgorithm.from_jwk(json.dumps(
response_json.get("keys")[0]))
return public_key
def send_request_to_auth_server(self, url, access_token=None):
verify = None
if urllib.parse.urlparse(url).scheme == "https":
verify = False if self.insecure else self.cafile
cert = (self.certfile, self.keyfile) \
if self.certfile and self.keyfile else None
headers = {}
if access_token:
headers["Authorization"] = "Bearer %s" % access_token
try:
resp = requests.get(
url,
headers=headers,
verify=verify,
cert=cert
)
except requests.ConnectionError as e:
msg = _("Can't connect to keycloak server with address '%s'."
) % url
LOG.error(msg, e)
raise exception.GlareException(message=msg)
if resp.status_code == 400:
raise exception.BadRequest(message=resp.text)
if resp.status_code == 401:
LOG.warning("HTTP response from OIDC provider:"
" [%s] with WWW-Authenticate: [%s]",
pprint.pformat(resp.text),
resp.headers.get("WWW-Authenticate"))
raise exception.Unauthorized(message=resp.text)
if resp.status_code == 403:
raise exception.Forbidden(message=resp.text)
elif resp.status_code > 400:
raise exception.GlareException(message=resp.text)
return resp.json()

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import mock
import requests
import webob
@ -35,13 +36,27 @@ class TestKeycloakAuthMiddleware(base.BaseTestCase):
def test_header_parsing(self, mocked_get):
token = {
"iss": "http://localhost:8080/auth/realms/my_realm",
"aud": "openstack",
"realm_access": {
"roles": ["role1", "role2"]
}
}
mocked_resp = mock.Mock()
mocked_resp.status_code = 200
mocked_resp.json.return_value = '{"user": "mike"}'
mocked_resp.json.return_value = json.loads("""
{
"keys": [
{
"kid": "FJ86GcF3jTbNLOco4NvZkUCIUmfYCqoqtOQeMfbhNlE",
"kty": "RSA",
"alg": "RS256",
"use": "sig",
"n": "q1awrk7QK24Gmcy9Yb4dMbS-ZnO6",
"e": "AQAB"
}
]
}
""")
mocked_get.return_value = mocked_resp
req = self._build_request(token)
@ -58,7 +73,10 @@ class TestKeycloakAuthMiddleware(base.BaseTestCase):
@mock.patch("requests.get")
def test_no_realm_access(self, mocked_get):
self.config(user_info_endpoint_url='https://127.0.0.1:9080',
group='keycloak_oidc')
token = {
"aud": "openstack",
"iss": "http://localhost:8080/auth/realms/my_realm",
}
mocked_resp = mock.Mock()
@ -79,12 +97,14 @@ class TestKeycloakAuthMiddleware(base.BaseTestCase):
@mock.patch("requests.get")
def test_server_unauthorized(self, mocked_get):
self.config(user_info_endpoint_url='https://127.0.0.1:9080',
group='keycloak_oidc')
token = {
"aud": "openstack",
"iss": "http://localhost:8080/auth/realms/my_realm",
}
mocked_resp = mock.Mock()
mocked_resp.status_code = 401
mocked_resp.json.return_value = '{"user": "mike"}'
mocked_get.return_value = mocked_resp
req = self._build_request(token)
@ -93,12 +113,14 @@ class TestKeycloakAuthMiddleware(base.BaseTestCase):
@mock.patch("requests.get")
def test_server_forbidden(self, mocked_get):
self.config(user_info_endpoint_url='https://127.0.0.1:9080',
group='keycloak_oidc')
token = {
"aud": "openstack",
"iss": "http://localhost:8080/auth/realms/my_realm",
}
mocked_resp = mock.Mock()
mocked_resp.status_code = 403
mocked_resp.json.return_value = '{"user": "mike"}'
mocked_get.return_value = mocked_resp
req = self._build_request(token)
@ -108,11 +130,12 @@ class TestKeycloakAuthMiddleware(base.BaseTestCase):
@mock.patch("requests.get")
def test_server_exception(self, mocked_get):
token = {
"iss": "http://localhost:8080/auth/realms/my_realm",
"aud": "openstack",
"iss": "http://localhost:8080/auth/realms/my_realm"
}
mocked_resp = mock.Mock()
mocked_resp.status_code = 500
mocked_resp.json.return_value = '{"user": "mike"}'
mocked_resp.json.return_value = "Internal Server Error"
mocked_get.return_value = mocked_resp
req = self._build_request(token)
@ -123,6 +146,7 @@ class TestKeycloakAuthMiddleware(base.BaseTestCase):
@mock.patch("requests.get")
def test_connection_error(self, mocked_get):
token = {
"aud": "openstack",
"iss": "http://localhost:8080/auth/realms/my_realm",
}
mocked_get.side_effect = requests.ConnectionError
@ -137,16 +161,35 @@ class TestKeycloakAuthMiddleware(base.BaseTestCase):
self.config(user_info_endpoint_url='',
group='keycloak_oidc')
token = {
"aud": "openstack",
"iss": "http://localhost:8080/auth/realms/my_realm",
"realm_access": {
"roles": ["role1", "role2"]
}
}
mocked_resp = mock.Mock()
mocked_resp.status_code = 200
mocked_resp.json.return_value = json.loads("""
{
"keys": [
{
"kid": "FJ86GcF3jTbNLOco4NvZkUCIUmfYCqoqtOQeMfbhNlE",
"kty": "RSA",
"alg": "RS256",
"use": "sig",
"n": "q1awrk7QK24Gmcy9Yb4dMbS-ZnO6",
"e": "AQAB"
}
]
}
""")
mocked_get.return_value = mocked_resp
req = self._build_request(token)
with mock.patch("jwt.decode", return_value=token):
self._build_middleware()(req)
self.assertEqual("Confirmed", req.headers["X-Identity-Status"])
self.assertEqual("my_realm", req.headers["X-Project-Id"])
self.assertEqual("role1,role2", req.headers["X-Roles"])
self.assertEqual(0, mocked_get.call_count)
self.assertEqual(1, mocked_get.call_count)