Fixed a problem in parsing metadata extensions.

This commit is contained in:
Roland Hedberg
2015-11-18 10:38:39 +01:00
parent 0218b0b064
commit 2ce425c84c
6 changed files with 167 additions and 83 deletions

View File

@@ -979,7 +979,7 @@ def extension_elements_to_elements(extension_elements, schemas):
if isinstance(schemas, list):
pass
elif isinstance(schemas, dict):
schemas = schemas.values()
schemas = list(schemas.values())
else:
return res

View File

@@ -1,19 +1,20 @@
from __future__ import print_function
import hashlib
import logging
import os
import sys
import json
import six
import requests
import six
from hashlib import sha1
from os.path import isfile, join
from saml2.httpbase import HTTPBase
from saml2.extension.idpdisc import BINDING_DISCO
from saml2.extension.idpdisc import DiscoveryResponse
from saml2.md import EntitiesDescriptor
from saml2.mdie import to_dict
from saml2 import md
from saml2 import samlp
from saml2 import SAMLError
@@ -67,6 +68,20 @@ ENTITY_CATEGORY_SUPPORT = "http://macedir.org/entity-category-support"
# ---------------------------------------------------
def load_extensions():
from saml2 import extension
import pkgutil
package = extension
prefix = package.__name__ + "."
ext_map = {}
for importer, modname, ispkg in pkgutil.iter_modules(package.__path__,
prefix):
module = __import__(modname, fromlist="dummy")
ext_map[module.NAMESPACE] = module
return ext_map
def destinations(srvs):
return [s["location"] for s in srvs]
@@ -564,8 +579,8 @@ class InMemoryMetaData(MetaData):
return True
node_name = self.node_name \
or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
md.EntitiesDescriptor.c_tag)
or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
md.EntitiesDescriptor.c_tag)
if self.security.verify_signature(
txt, node_name=node_name, cert_file=self.cert):
@@ -705,27 +720,31 @@ class MetaDataMDX(InMemoryMetaData):
""" Uses the md protocol to fetch entity information
"""
def __init__(self, entity_transform, onts, attrc, url, security, cert,
http, **kwargs):
@staticmethod
def sha1_entity_transform(entity_id):
return "{{sha1}}{}".format(
hashlib.sha1(entity_id.encode("utf-8")).hexdigest())
def __init__(self, url, entity_transform=None):
"""
:params entity_transform: function transforming (e.g. base64 or sha1
:params url: mdx service url
:params entity_transform: function transforming (e.g. base64,
sha1 hash or URL quote
hash) the entity id. It is applied to the entity id before it is
concatenated with the request URL sent to the MDX server.
:params onts:
:params attrc:
:params url:
:params security: SecurityContext()
:params cert:
:params http:
concatenated with the request URL sent to the MDX server. Defaults to
sha1 transformation.
"""
super(MetaDataMDX, self).__init__(onts, attrc, **kwargs)
super(MetaDataMDX, self).__init__(None, None)
self.url = url
self.security = security
self.cert = cert
self.http = http
self.entity_transform = entity_transform
if entity_transform:
self.entity_transform = entity_transform
else:
self.entity_transform = MetaDataMDX.sha1_entity_transform
def load(self):
# Do nothing
pass
def __getitem__(self, item):
@@ -733,13 +752,9 @@ class MetaDataMDX(InMemoryMetaData):
return self.entity[item]
except KeyError:
mdx_url = "%s/entities/%s" % (self.url, self.entity_transform(item))
response = self.http.send(
mdx_url, headers={'Accept': SAML_METADATA_CONTENT_TYPE})
response = requests.get(mdx_url, headers={
'Accept': SAML_METADATA_CONTENT_TYPE})
if response.status_code == 200:
node_name = self.node_name \
or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
md.EntitiesDescriptor.c_tag)
_txt = response.text.encode("utf-8")
if self.parse_and_check_signature(_txt):
@@ -748,6 +763,12 @@ class MetaDataMDX(InMemoryMetaData):
logger.info("Response status: %s", response.status_code)
raise KeyError
def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"):
if binding is None:
binding = BINDING_HTTP_REDIRECT
return self.service(entity_id, "idpsso_descriptor",
"single_sign_on_service", binding)
class MetadataStore(MetaData):
def __init__(self, onts, attrc, config, ca_certs=None,

View File

@@ -58,6 +58,7 @@ from saml2.validate import NotValid
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
@@ -160,9 +161,11 @@ class StatusUnknownPrincipal(StatusError):
class StatusUnsupportedBinding(StatusError):
pass
class StatusResponder(StatusError):
pass
STATUSCODE2EXCEPTION = {
STATUS_VERSION_MISMATCH: StatusVersionMismatch,
STATUS_AUTHN_FAILED: StatusAuthnFailed,
@@ -186,6 +189,8 @@ STATUSCODE2EXCEPTION = {
STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding,
STATUS_RESPONDER: StatusResponder,
}
# ---------------------------------------------------------------------------
@@ -206,7 +211,8 @@ def for_me(conditions, myself):
if audience.text.strip() == myself:
return True
else:
#print("Not for me: %s != %s" % (audience.text.strip(), myself))
# print("Not for me: %s != %s" % (audience.text.strip(),
# myself))
pass
return False
@@ -336,7 +342,7 @@ class StatusResponse(object):
logger.exception("EXCEPTION: %s", excp)
raise
#print("<", self.response)
# print("<", self.response)
return self._postamble()
@@ -377,7 +383,7 @@ class StatusResponse(object):
if self.request_id and self.in_response_to and \
self.in_response_to != self.request_id:
logger.error("Not the id I expected: %s != %s",
self.in_response_to, self.request_id)
self.in_response_to, self.request_id)
return None
try:
@@ -391,9 +397,9 @@ class StatusResponse(object):
if self.asynchop:
if self.response.destination and \
self.response.destination not in self.return_addrs:
self.response.destination not in self.return_addrs:
logger.error("%s not in %s", self.response.destination,
self.return_addrs)
self.return_addrs)
return None
assert self.issue_instant_ok()
@@ -436,7 +442,7 @@ class NameIDMappingResponse(StatusResponse):
request_id=0, asynchop=True):
StatusResponse.__init__(self, sec_context, return_addrs, timeslack,
request_id, asynchop)
self.signature_check = self.sec\
self.signature_check = self.sec \
.correctly_signed_name_id_mapping_response
@@ -506,7 +512,7 @@ class AuthnResponse(StatusResponse):
if self.asynchop:
if self.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[self.in_response_to]
#del self.outstanding_queries[self.in_response_to]
# del self.outstanding_queries[self.in_response_to]
try:
if not self.check_subject_confirmation_in_response_to(
self.in_response_to):
@@ -632,12 +638,12 @@ class AuthnResponse(StatusResponse):
def read_attribute_statement(self, attr_statem):
logger.debug("Attribute Statement: %s", attr_statem)
for aconv in self.attribute_converters:
logger.debug("Converts name format: %s", aconv.name_format)
# for aconv in self.attribute_converters:
# logger.debug("Converts name format: %s", aconv.name_format)
self.decrypt_attributes(attr_statem)
return to_local(self.attribute_converters, attr_statem,
self.allow_unknown_attributes)
self.allow_unknown_attributes)
def get_identity(self):
""" The assertion can contain zero or one attributeStatements
@@ -650,7 +656,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in _assertion.advice.assertion:
if tmp_assertion.attribute_statement:
assert len(tmp_assertion.attribute_statement) == 1
ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
ava.update(self.read_attribute_statement(
tmp_assertion.attribute_statement[0]))
if _assertion.attribute_statement:
assert len(_assertion.attribute_statement) == 1
_attr_statem = _assertion.attribute_statement[0]
@@ -681,7 +688,7 @@ class AuthnResponse(StatusResponse):
if data.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[
data.in_response_to]
#del self.outstanding_queries[data.in_response_to]
# del self.outstanding_queries[data.in_response_to]
elif self.allow_unsolicited:
pass
else:
@@ -690,7 +697,7 @@ class AuthnResponse(StatusResponse):
# recognize
logger.debug("in response to: '%s'", data.in_response_to)
logger.info("outstanding queries: %s",
self.outstanding_queries.keys())
self.outstanding_queries.keys())
raise Exception(
"Combination of session id and requestURI I don't "
"recall")
@@ -768,7 +775,8 @@ class AuthnResponse(StatusResponse):
logger.debug("signed")
if not verified and self.do_not_verify is False:
try:
self.sec.check_signature(assertion, class_name(assertion),self.xmlstr)
self.sec.check_signature(assertion, class_name(assertion),
self.xmlstr)
except Exception as exc:
logger.error("correctly_signed_response: %s", exc)
raise
@@ -778,10 +786,10 @@ class AuthnResponse(StatusResponse):
logger.debug("assertion keys: %s", assertion.keyswv())
logger.debug("outstanding_queries: %s", self.outstanding_queries)
#if self.context == "AuthnReq" or self.context == "AttrQuery":
# if self.context == "AuthnReq" or self.context == "AttrQuery":
if self.context == "AuthnReq":
self.authn_statement_ok()
# elif self.context == "AttrQuery":
# elif self.context == "AttrQuery":
# self.authn_statement_ok(True)
if not self.condition_ok():
@@ -789,7 +797,7 @@ class AuthnResponse(StatusResponse):
logger.debug("--- Getting Identity ---")
#if self.context == "AuthnReq" or self.context == "AttrQuery":
# if self.context == "AuthnReq" or self.context == "AttrQuery":
# self.ava = self.get_identity()
# logger.debug("--- AVA: %s", self.ava)
@@ -805,13 +813,17 @@ class AuthnResponse(StatusResponse):
logger.exception("get subject")
raise
def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False):
""" Moves the decrypted assertion from the encrypted assertion to a list.
def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None,
verified=False):
""" Moves the decrypted assertion from the encrypted assertion to a
list.
:param encrypted_assertions: A list of encrypted assertions.
:param decr_txt: The string representation containing the decrypted data. Used when verifying signatures.
:param decr_txt: The string representation containing the decrypted
data. Used when verifying signatures.
:param issuer: The issuer of the response.
:param verified: If True do not verify signatures, otherwise verify the signature if it exists.
:param verified: If True do not verify signatures, otherwise verify
the signature if it exists.
:return: A list of decrypted assertions.
"""
res = []
@@ -824,7 +836,8 @@ class AuthnResponse(StatusResponse):
if not self.sec.check_signature(
assertion, origdoc=decr_txt,
node_name=class_name(assertion), issuer=issuer):
logger.error("Failed to verify signature on '%s'", assertion)
logger.error("Failed to verify signature on '%s'",
assertion)
raise SignatureError()
res.append(assertion)
return res
@@ -836,11 +849,12 @@ class AuthnResponse(StatusResponse):
:return: True encrypted data exists otherwise false.
"""
for _assertion in enc_assertions:
if _assertion.encrypted_data is not None:
return True
if _assertion.encrypted_data is not None:
return True
def find_encrypt_data_assertion_list(self, _assertions):
""" Verifies if a list of assertions contains encrypted data in the advice element.
""" Verifies if a list of assertions contains encrypted data in the
advice element.
:param _assertions: A list of assertions.
:return: True encrypted data exists otherwise false.
@@ -848,12 +862,14 @@ class AuthnResponse(StatusResponse):
for _assertion in _assertions:
if _assertion.advice:
if _assertion.advice.encrypted_assertion:
res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion)
res = self.find_encrypt_data_assertion(
_assertion.advice.encrypted_assertion)
if res:
return True
def find_encrypt_data(self, resp):
""" Verifies if a saml response contains encrypted assertions with encrypted data.
""" Verifies if a saml response contains encrypted assertions with
encrypted data.
:param resp: A saml response.
:return: True encrypted data exists otherwise false.
@@ -867,7 +883,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in resp.assertion:
if tmp_assertion.advice:
if tmp_assertion.advice.encrypted_assertion:
res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion)
res = self.find_encrypt_data_assertion(
tmp_assertion.advice.encrypted_assertion)
if res:
return True
return False
@@ -875,7 +892,8 @@ class AuthnResponse(StatusResponse):
def parse_assertion(self, keys=None):
""" Parse the assertions for a saml response.
:param keys: A string representing a RSA key or a list of strings containing RSA keys.
:param keys: A string representing a RSA key or a list of strings
containing RSA keys.
:return: True if the assertions are parsed otherwise False.
"""
if self.context == "AuthnQuery":
@@ -884,12 +902,13 @@ class AuthnResponse(StatusResponse):
else: # This is a saml2int limitation
try:
assert len(self.response.assertion) == 1 or \
len(self.response.encrypted_assertion) == 1
len(self.response.encrypted_assertion) == 1
except AssertionError:
raise Exception("No assertion part")
has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion
#if not has_encrypted_assertions and self.response.assertion:
has_encrypted_assertions = self.find_encrypt_data(self.response) #
# self.response.encrypted_assertion
# if not has_encrypted_assertions and self.response.assertion:
# for tmp_assertion in self.response.assertion:
# if tmp_assertion.advice:
# if tmp_assertion.advice.encrypted_assertion:
@@ -912,15 +931,20 @@ class AuthnResponse(StatusResponse):
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion,
decr_text)
decr_text_old = None
while (self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions)) and \
while (self.find_encrypt_data(
resp) or self.find_encrypt_data_assertion_list(
_enc_assertions)) and \
decr_text_old != decr_text:
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
#_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
_enc_assertions = self.decrypt_assertions(
resp.encrypted_assertion, decr_text, verified=True)
# _enc_assertions = self.decrypt_assertions(
# resp.encrypted_assertion, decr_text, verified=True)
all_assertions = _enc_assertions
if resp.assertion:
all_assertions = all_assertions + resp.assertion
@@ -928,9 +952,10 @@ class AuthnResponse(StatusResponse):
for tmp_ass in all_assertions:
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion,
decr_text,
tmp_ass.issuer)
advice_res = self.decrypt_assertions(
tmp_ass.advice.encrypted_assertion,
decr_text,
tmp_ass.issuer)
if tmp_ass.advice.assertion:
tmp_ass.advice.assertion.extend(advice_res)
else:
@@ -1211,7 +1236,7 @@ class AssertionIDResponse(object):
logger.exception("EXCEPTION: %s", excp)
raise
#print("<", self.response)
# print("<", self.response)
return self._postamble()
@@ -1233,4 +1258,3 @@ class AssertionIDResponse(object):
logger.debug("response: %s", self.response)
return self

View File

@@ -2,11 +2,14 @@
# -*- coding: utf-8 -*-
import datetime
import re
from six.moves.urllib.parse import quote_plus
#from six.moves.urllib.parse import quote_plus
from future.backports.urllib.parse import quote_plus
from saml2.config import Config
from saml2.httpbase import HTTPBase
from saml2.mdstore import MetadataStore, MetaDataMDX
from saml2.mdstore import MetadataStore
from saml2.mdstore import MetaDataMDX
from saml2.mdstore import SAML_METADATA_CONTENT_TYPE
from saml2.mdstore import destinations
from saml2.mdstore import load_extensions
from saml2.mdstore import name
from saml2 import md
from saml2 import sigver
@@ -18,16 +21,13 @@ from saml2 import saml
from saml2 import config
from saml2.attribute_converter import ac_factory
from saml2.attribute_converter import d_to_local_name
from saml2.extension import mdui
from saml2.extension import idpdisc
from saml2.extension import dri
from saml2.extension import mdattr
from saml2.extension import ui
from saml2.s_utils import UnknownPrincipal
from saml2 import xmldsig
from saml2 import xmlenc
from pathutils import full_path
import responses
sec_config = config.Config()
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
@@ -88,16 +88,13 @@ TEST_METADATA_STRING = """
ONTS = {
saml.NAMESPACE: saml,
mdui.NAMESPACE: mdui,
mdattr.NAMESPACE: mdattr,
dri.NAMESPACE: dri,
ui.NAMESPACE: ui,
idpdisc.NAMESPACE: idpdisc,
md.NAMESPACE: md,
xmldsig.NAMESPACE: xmldsig,
xmlenc.NAMESPACE: xmlenc
}
ONTS.update(load_extensions())
ATTRCONV = ac_factory(full_path("attributemaps"))
METADATACONF = {
@@ -150,6 +147,10 @@ METADATACONF = {
"class": "saml2.mdstore.InMemoryMetaData",
"metadata": [(TEST_METADATA_STRING,)]
}],
"12": [{
"class": "saml2.mdstore.MetaDataFile",
"metadata": [(full_path("uu.xml"),)],
}],
}
@@ -303,6 +304,36 @@ def test_metadata_file():
assert len(mds.keys()) == 560
@responses.activate
def test_mdx_service():
entity_id = "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
url = "http://mdx.example.com/entities/{}".format(
quote_plus(MetaDataMDX.sha1_entity_transform(entity_id)))
responses.add(responses.GET, url, body=TEST_METADATA_STRING, status=200,
content_type=SAML_METADATA_CONTENT_TYPE)
mdx = MetaDataMDX("http://mdx.example.com")
sso_loc = mdx.service(entity_id, "idpsso_descriptor", "single_sign_on_service")
assert sso_loc[BINDING_HTTP_REDIRECT][0]["location"] == "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
certs = mdx.certs(entity_id, "idpsso")
assert len(certs) == 1
@responses.activate
def test_mdx_single_sign_on_service():
entity_id = "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
url = "http://mdx.example.com/entities/{}".format(
quote_plus(MetaDataMDX.sha1_entity_transform(entity_id)))
responses.add(responses.GET, url, body=TEST_METADATA_STRING, status=200,
content_type=SAML_METADATA_CONTENT_TYPE)
mdx = MetaDataMDX("http://mdx.example.com")
sso_loc = mdx.single_sign_on_service(entity_id, BINDING_HTTP_REDIRECT)
assert sso_loc[0]["location"] == "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
# pyff-test not available
# def test_mdx_service():
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
@@ -429,6 +460,12 @@ def test_get_certs_from_metadata_without_keydescriptor():
assert len(certs) == 0
def test_metadata_extension_algsupport():
mds = MetadataStore(list(ONTS.values()), ATTRCONV, None)
mds.imp(METADATACONF["12"])
mdf = mds.metadata[full_path("uu.xml")]
_txt = mdf.dumps()
assert mds
if __name__ == "__main__":
test_get_certs_from_metadata()
test_metadata_extension_algsupport()

View File

@@ -0,0 +1,2 @@
pymongo==3.0.1
responses==0.5.0

View File

@@ -3,5 +3,5 @@ envlist = py27,py34
[testenv]
deps = pytest
pymongo==3.0.1
-rtests/test_requirements.txt
commands = py.test tests/