Fixed a problem in parsing metadata extensions.
This commit is contained in:
@@ -979,7 +979,7 @@ def extension_elements_to_elements(extension_elements, schemas):
|
|||||||
if isinstance(schemas, list):
|
if isinstance(schemas, list):
|
||||||
pass
|
pass
|
||||||
elif isinstance(schemas, dict):
|
elif isinstance(schemas, dict):
|
||||||
schemas = schemas.values()
|
schemas = list(schemas.values())
|
||||||
else:
|
else:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,20 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import six
|
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import six
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from os.path import isfile, join
|
from os.path import isfile, join
|
||||||
from saml2.httpbase import HTTPBase
|
from saml2.httpbase import HTTPBase
|
||||||
from saml2.extension.idpdisc import BINDING_DISCO
|
from saml2.extension.idpdisc import BINDING_DISCO
|
||||||
from saml2.extension.idpdisc import DiscoveryResponse
|
from saml2.extension.idpdisc import DiscoveryResponse
|
||||||
from saml2.md import EntitiesDescriptor
|
from saml2.md import EntitiesDescriptor
|
||||||
|
|
||||||
from saml2.mdie import to_dict
|
from saml2.mdie import to_dict
|
||||||
|
|
||||||
from saml2 import md
|
from saml2 import md
|
||||||
from saml2 import samlp
|
from saml2 import samlp
|
||||||
from saml2 import SAMLError
|
from saml2 import SAMLError
|
||||||
@@ -67,6 +68,20 @@ ENTITY_CATEGORY_SUPPORT = "http://macedir.org/entity-category-support"
|
|||||||
|
|
||||||
# ---------------------------------------------------
|
# ---------------------------------------------------
|
||||||
|
|
||||||
|
def load_extensions():
|
||||||
|
from saml2 import extension
|
||||||
|
import pkgutil
|
||||||
|
|
||||||
|
package = extension
|
||||||
|
prefix = package.__name__ + "."
|
||||||
|
ext_map = {}
|
||||||
|
for importer, modname, ispkg in pkgutil.iter_modules(package.__path__,
|
||||||
|
prefix):
|
||||||
|
module = __import__(modname, fromlist="dummy")
|
||||||
|
ext_map[module.NAMESPACE] = module
|
||||||
|
|
||||||
|
return ext_map
|
||||||
|
|
||||||
|
|
||||||
def destinations(srvs):
|
def destinations(srvs):
|
||||||
return [s["location"] for s in srvs]
|
return [s["location"] for s in srvs]
|
||||||
@@ -564,8 +579,8 @@ class InMemoryMetaData(MetaData):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
node_name = self.node_name \
|
node_name = self.node_name \
|
||||||
or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
|
or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
|
||||||
md.EntitiesDescriptor.c_tag)
|
md.EntitiesDescriptor.c_tag)
|
||||||
|
|
||||||
if self.security.verify_signature(
|
if self.security.verify_signature(
|
||||||
txt, node_name=node_name, cert_file=self.cert):
|
txt, node_name=node_name, cert_file=self.cert):
|
||||||
@@ -705,27 +720,31 @@ class MetaDataMDX(InMemoryMetaData):
|
|||||||
""" Uses the md protocol to fetch entity information
|
""" Uses the md protocol to fetch entity information
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, entity_transform, onts, attrc, url, security, cert,
|
@staticmethod
|
||||||
http, **kwargs):
|
def sha1_entity_transform(entity_id):
|
||||||
|
return "{{sha1}}{}".format(
|
||||||
|
hashlib.sha1(entity_id.encode("utf-8")).hexdigest())
|
||||||
|
|
||||||
|
def __init__(self, url, entity_transform=None):
|
||||||
"""
|
"""
|
||||||
:params entity_transform: function transforming (e.g. base64 or sha1
|
:params url: mdx service url
|
||||||
|
:params entity_transform: function transforming (e.g. base64,
|
||||||
|
sha1 hash or URL quote
|
||||||
hash) the entity id. It is applied to the entity id before it is
|
hash) the entity id. It is applied to the entity id before it is
|
||||||
concatenated with the request URL sent to the MDX server.
|
concatenated with the request URL sent to the MDX server. Defaults to
|
||||||
:params onts:
|
sha1 transformation.
|
||||||
:params attrc:
|
|
||||||
:params url:
|
|
||||||
:params security: SecurityContext()
|
|
||||||
:params cert:
|
|
||||||
:params http:
|
|
||||||
"""
|
"""
|
||||||
super(MetaDataMDX, self).__init__(onts, attrc, **kwargs)
|
super(MetaDataMDX, self).__init__(None, None)
|
||||||
self.url = url
|
self.url = url
|
||||||
self.security = security
|
|
||||||
self.cert = cert
|
if entity_transform:
|
||||||
self.http = http
|
self.entity_transform = entity_transform
|
||||||
self.entity_transform = entity_transform
|
else:
|
||||||
|
|
||||||
|
self.entity_transform = MetaDataMDX.sha1_entity_transform
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
# Do nothing
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
@@ -733,13 +752,9 @@ class MetaDataMDX(InMemoryMetaData):
|
|||||||
return self.entity[item]
|
return self.entity[item]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
mdx_url = "%s/entities/%s" % (self.url, self.entity_transform(item))
|
mdx_url = "%s/entities/%s" % (self.url, self.entity_transform(item))
|
||||||
response = self.http.send(
|
response = requests.get(mdx_url, headers={
|
||||||
mdx_url, headers={'Accept': SAML_METADATA_CONTENT_TYPE})
|
'Accept': SAML_METADATA_CONTENT_TYPE})
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
node_name = self.node_name \
|
|
||||||
or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
|
|
||||||
md.EntitiesDescriptor.c_tag)
|
|
||||||
|
|
||||||
_txt = response.text.encode("utf-8")
|
_txt = response.text.encode("utf-8")
|
||||||
|
|
||||||
if self.parse_and_check_signature(_txt):
|
if self.parse_and_check_signature(_txt):
|
||||||
@@ -748,6 +763,12 @@ class MetaDataMDX(InMemoryMetaData):
|
|||||||
logger.info("Response status: %s", response.status_code)
|
logger.info("Response status: %s", response.status_code)
|
||||||
raise KeyError
|
raise KeyError
|
||||||
|
|
||||||
|
def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"):
|
||||||
|
if binding is None:
|
||||||
|
binding = BINDING_HTTP_REDIRECT
|
||||||
|
return self.service(entity_id, "idpsso_descriptor",
|
||||||
|
"single_sign_on_service", binding)
|
||||||
|
|
||||||
|
|
||||||
class MetadataStore(MetaData):
|
class MetadataStore(MetaData):
|
||||||
def __init__(self, onts, attrc, config, ca_certs=None,
|
def __init__(self, onts, attrc, config, ca_certs=None,
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ from saml2.validate import NotValid
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -160,9 +161,11 @@ class StatusUnknownPrincipal(StatusError):
|
|||||||
class StatusUnsupportedBinding(StatusError):
|
class StatusUnsupportedBinding(StatusError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class StatusResponder(StatusError):
|
class StatusResponder(StatusError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
STATUSCODE2EXCEPTION = {
|
STATUSCODE2EXCEPTION = {
|
||||||
STATUS_VERSION_MISMATCH: StatusVersionMismatch,
|
STATUS_VERSION_MISMATCH: StatusVersionMismatch,
|
||||||
STATUS_AUTHN_FAILED: StatusAuthnFailed,
|
STATUS_AUTHN_FAILED: StatusAuthnFailed,
|
||||||
@@ -186,6 +189,8 @@ STATUSCODE2EXCEPTION = {
|
|||||||
STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding,
|
STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding,
|
||||||
STATUS_RESPONDER: StatusResponder,
|
STATUS_RESPONDER: StatusResponder,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -206,7 +211,8 @@ def for_me(conditions, myself):
|
|||||||
if audience.text.strip() == myself:
|
if audience.text.strip() == myself:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
#print("Not for me: %s != %s" % (audience.text.strip(), myself))
|
# print("Not for me: %s != %s" % (audience.text.strip(),
|
||||||
|
# myself))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -336,7 +342,7 @@ class StatusResponse(object):
|
|||||||
logger.exception("EXCEPTION: %s", excp)
|
logger.exception("EXCEPTION: %s", excp)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
#print("<", self.response)
|
# print("<", self.response)
|
||||||
|
|
||||||
return self._postamble()
|
return self._postamble()
|
||||||
|
|
||||||
@@ -377,7 +383,7 @@ class StatusResponse(object):
|
|||||||
if self.request_id and self.in_response_to and \
|
if self.request_id and self.in_response_to and \
|
||||||
self.in_response_to != self.request_id:
|
self.in_response_to != self.request_id:
|
||||||
logger.error("Not the id I expected: %s != %s",
|
logger.error("Not the id I expected: %s != %s",
|
||||||
self.in_response_to, self.request_id)
|
self.in_response_to, self.request_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -391,9 +397,9 @@ class StatusResponse(object):
|
|||||||
|
|
||||||
if self.asynchop:
|
if self.asynchop:
|
||||||
if self.response.destination and \
|
if self.response.destination and \
|
||||||
self.response.destination not in self.return_addrs:
|
self.response.destination not in self.return_addrs:
|
||||||
logger.error("%s not in %s", self.response.destination,
|
logger.error("%s not in %s", self.response.destination,
|
||||||
self.return_addrs)
|
self.return_addrs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
assert self.issue_instant_ok()
|
assert self.issue_instant_ok()
|
||||||
@@ -436,7 +442,7 @@ class NameIDMappingResponse(StatusResponse):
|
|||||||
request_id=0, asynchop=True):
|
request_id=0, asynchop=True):
|
||||||
StatusResponse.__init__(self, sec_context, return_addrs, timeslack,
|
StatusResponse.__init__(self, sec_context, return_addrs, timeslack,
|
||||||
request_id, asynchop)
|
request_id, asynchop)
|
||||||
self.signature_check = self.sec\
|
self.signature_check = self.sec \
|
||||||
.correctly_signed_name_id_mapping_response
|
.correctly_signed_name_id_mapping_response
|
||||||
|
|
||||||
|
|
||||||
@@ -506,7 +512,7 @@ class AuthnResponse(StatusResponse):
|
|||||||
if self.asynchop:
|
if self.asynchop:
|
||||||
if self.in_response_to in self.outstanding_queries:
|
if self.in_response_to in self.outstanding_queries:
|
||||||
self.came_from = self.outstanding_queries[self.in_response_to]
|
self.came_from = self.outstanding_queries[self.in_response_to]
|
||||||
#del self.outstanding_queries[self.in_response_to]
|
# del self.outstanding_queries[self.in_response_to]
|
||||||
try:
|
try:
|
||||||
if not self.check_subject_confirmation_in_response_to(
|
if not self.check_subject_confirmation_in_response_to(
|
||||||
self.in_response_to):
|
self.in_response_to):
|
||||||
@@ -632,12 +638,12 @@ class AuthnResponse(StatusResponse):
|
|||||||
|
|
||||||
def read_attribute_statement(self, attr_statem):
|
def read_attribute_statement(self, attr_statem):
|
||||||
logger.debug("Attribute Statement: %s", attr_statem)
|
logger.debug("Attribute Statement: %s", attr_statem)
|
||||||
for aconv in self.attribute_converters:
|
# for aconv in self.attribute_converters:
|
||||||
logger.debug("Converts name format: %s", aconv.name_format)
|
# logger.debug("Converts name format: %s", aconv.name_format)
|
||||||
|
|
||||||
self.decrypt_attributes(attr_statem)
|
self.decrypt_attributes(attr_statem)
|
||||||
return to_local(self.attribute_converters, attr_statem,
|
return to_local(self.attribute_converters, attr_statem,
|
||||||
self.allow_unknown_attributes)
|
self.allow_unknown_attributes)
|
||||||
|
|
||||||
def get_identity(self):
|
def get_identity(self):
|
||||||
""" The assertion can contain zero or one attributeStatements
|
""" The assertion can contain zero or one attributeStatements
|
||||||
@@ -650,7 +656,8 @@ class AuthnResponse(StatusResponse):
|
|||||||
for tmp_assertion in _assertion.advice.assertion:
|
for tmp_assertion in _assertion.advice.assertion:
|
||||||
if tmp_assertion.attribute_statement:
|
if tmp_assertion.attribute_statement:
|
||||||
assert len(tmp_assertion.attribute_statement) == 1
|
assert len(tmp_assertion.attribute_statement) == 1
|
||||||
ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
|
ava.update(self.read_attribute_statement(
|
||||||
|
tmp_assertion.attribute_statement[0]))
|
||||||
if _assertion.attribute_statement:
|
if _assertion.attribute_statement:
|
||||||
assert len(_assertion.attribute_statement) == 1
|
assert len(_assertion.attribute_statement) == 1
|
||||||
_attr_statem = _assertion.attribute_statement[0]
|
_attr_statem = _assertion.attribute_statement[0]
|
||||||
@@ -681,7 +688,7 @@ class AuthnResponse(StatusResponse):
|
|||||||
if data.in_response_to in self.outstanding_queries:
|
if data.in_response_to in self.outstanding_queries:
|
||||||
self.came_from = self.outstanding_queries[
|
self.came_from = self.outstanding_queries[
|
||||||
data.in_response_to]
|
data.in_response_to]
|
||||||
#del self.outstanding_queries[data.in_response_to]
|
# del self.outstanding_queries[data.in_response_to]
|
||||||
elif self.allow_unsolicited:
|
elif self.allow_unsolicited:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@@ -690,7 +697,7 @@ class AuthnResponse(StatusResponse):
|
|||||||
# recognize
|
# recognize
|
||||||
logger.debug("in response to: '%s'", data.in_response_to)
|
logger.debug("in response to: '%s'", data.in_response_to)
|
||||||
logger.info("outstanding queries: %s",
|
logger.info("outstanding queries: %s",
|
||||||
self.outstanding_queries.keys())
|
self.outstanding_queries.keys())
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Combination of session id and requestURI I don't "
|
"Combination of session id and requestURI I don't "
|
||||||
"recall")
|
"recall")
|
||||||
@@ -768,7 +775,8 @@ class AuthnResponse(StatusResponse):
|
|||||||
logger.debug("signed")
|
logger.debug("signed")
|
||||||
if not verified and self.do_not_verify is False:
|
if not verified and self.do_not_verify is False:
|
||||||
try:
|
try:
|
||||||
self.sec.check_signature(assertion, class_name(assertion),self.xmlstr)
|
self.sec.check_signature(assertion, class_name(assertion),
|
||||||
|
self.xmlstr)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("correctly_signed_response: %s", exc)
|
logger.error("correctly_signed_response: %s", exc)
|
||||||
raise
|
raise
|
||||||
@@ -778,10 +786,10 @@ class AuthnResponse(StatusResponse):
|
|||||||
logger.debug("assertion keys: %s", assertion.keyswv())
|
logger.debug("assertion keys: %s", assertion.keyswv())
|
||||||
logger.debug("outstanding_queries: %s", self.outstanding_queries)
|
logger.debug("outstanding_queries: %s", self.outstanding_queries)
|
||||||
|
|
||||||
#if self.context == "AuthnReq" or self.context == "AttrQuery":
|
# if self.context == "AuthnReq" or self.context == "AttrQuery":
|
||||||
if self.context == "AuthnReq":
|
if self.context == "AuthnReq":
|
||||||
self.authn_statement_ok()
|
self.authn_statement_ok()
|
||||||
# elif self.context == "AttrQuery":
|
# elif self.context == "AttrQuery":
|
||||||
# self.authn_statement_ok(True)
|
# self.authn_statement_ok(True)
|
||||||
|
|
||||||
if not self.condition_ok():
|
if not self.condition_ok():
|
||||||
@@ -789,7 +797,7 @@ class AuthnResponse(StatusResponse):
|
|||||||
|
|
||||||
logger.debug("--- Getting Identity ---")
|
logger.debug("--- Getting Identity ---")
|
||||||
|
|
||||||
#if self.context == "AuthnReq" or self.context == "AttrQuery":
|
# if self.context == "AuthnReq" or self.context == "AttrQuery":
|
||||||
# self.ava = self.get_identity()
|
# self.ava = self.get_identity()
|
||||||
# logger.debug("--- AVA: %s", self.ava)
|
# logger.debug("--- AVA: %s", self.ava)
|
||||||
|
|
||||||
@@ -805,13 +813,17 @@ class AuthnResponse(StatusResponse):
|
|||||||
logger.exception("get subject")
|
logger.exception("get subject")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False):
|
def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None,
|
||||||
""" Moves the decrypted assertion from the encrypted assertion to a list.
|
verified=False):
|
||||||
|
""" Moves the decrypted assertion from the encrypted assertion to a
|
||||||
|
list.
|
||||||
|
|
||||||
:param encrypted_assertions: A list of encrypted assertions.
|
:param encrypted_assertions: A list of encrypted assertions.
|
||||||
:param decr_txt: The string representation containing the decrypted data. Used when verifying signatures.
|
:param decr_txt: The string representation containing the decrypted
|
||||||
|
data. Used when verifying signatures.
|
||||||
:param issuer: The issuer of the response.
|
:param issuer: The issuer of the response.
|
||||||
:param verified: If True do not verify signatures, otherwise verify the signature if it exists.
|
:param verified: If True do not verify signatures, otherwise verify
|
||||||
|
the signature if it exists.
|
||||||
:return: A list of decrypted assertions.
|
:return: A list of decrypted assertions.
|
||||||
"""
|
"""
|
||||||
res = []
|
res = []
|
||||||
@@ -824,7 +836,8 @@ class AuthnResponse(StatusResponse):
|
|||||||
if not self.sec.check_signature(
|
if not self.sec.check_signature(
|
||||||
assertion, origdoc=decr_txt,
|
assertion, origdoc=decr_txt,
|
||||||
node_name=class_name(assertion), issuer=issuer):
|
node_name=class_name(assertion), issuer=issuer):
|
||||||
logger.error("Failed to verify signature on '%s'", assertion)
|
logger.error("Failed to verify signature on '%s'",
|
||||||
|
assertion)
|
||||||
raise SignatureError()
|
raise SignatureError()
|
||||||
res.append(assertion)
|
res.append(assertion)
|
||||||
return res
|
return res
|
||||||
@@ -836,11 +849,12 @@ class AuthnResponse(StatusResponse):
|
|||||||
:return: True encrypted data exists otherwise false.
|
:return: True encrypted data exists otherwise false.
|
||||||
"""
|
"""
|
||||||
for _assertion in enc_assertions:
|
for _assertion in enc_assertions:
|
||||||
if _assertion.encrypted_data is not None:
|
if _assertion.encrypted_data is not None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def find_encrypt_data_assertion_list(self, _assertions):
|
def find_encrypt_data_assertion_list(self, _assertions):
|
||||||
""" Verifies if a list of assertions contains encrypted data in the advice element.
|
""" Verifies if a list of assertions contains encrypted data in the
|
||||||
|
advice element.
|
||||||
|
|
||||||
:param _assertions: A list of assertions.
|
:param _assertions: A list of assertions.
|
||||||
:return: True encrypted data exists otherwise false.
|
:return: True encrypted data exists otherwise false.
|
||||||
@@ -848,12 +862,14 @@ class AuthnResponse(StatusResponse):
|
|||||||
for _assertion in _assertions:
|
for _assertion in _assertions:
|
||||||
if _assertion.advice:
|
if _assertion.advice:
|
||||||
if _assertion.advice.encrypted_assertion:
|
if _assertion.advice.encrypted_assertion:
|
||||||
res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion)
|
res = self.find_encrypt_data_assertion(
|
||||||
|
_assertion.advice.encrypted_assertion)
|
||||||
if res:
|
if res:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def find_encrypt_data(self, resp):
|
def find_encrypt_data(self, resp):
|
||||||
""" Verifies if a saml response contains encrypted assertions with encrypted data.
|
""" Verifies if a saml response contains encrypted assertions with
|
||||||
|
encrypted data.
|
||||||
|
|
||||||
:param resp: A saml response.
|
:param resp: A saml response.
|
||||||
:return: True encrypted data exists otherwise false.
|
:return: True encrypted data exists otherwise false.
|
||||||
@@ -867,7 +883,8 @@ class AuthnResponse(StatusResponse):
|
|||||||
for tmp_assertion in resp.assertion:
|
for tmp_assertion in resp.assertion:
|
||||||
if tmp_assertion.advice:
|
if tmp_assertion.advice:
|
||||||
if tmp_assertion.advice.encrypted_assertion:
|
if tmp_assertion.advice.encrypted_assertion:
|
||||||
res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion)
|
res = self.find_encrypt_data_assertion(
|
||||||
|
tmp_assertion.advice.encrypted_assertion)
|
||||||
if res:
|
if res:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@@ -875,7 +892,8 @@ class AuthnResponse(StatusResponse):
|
|||||||
def parse_assertion(self, keys=None):
|
def parse_assertion(self, keys=None):
|
||||||
""" Parse the assertions for a saml response.
|
""" Parse the assertions for a saml response.
|
||||||
|
|
||||||
:param keys: A string representing a RSA key or a list of strings containing RSA keys.
|
:param keys: A string representing a RSA key or a list of strings
|
||||||
|
containing RSA keys.
|
||||||
:return: True if the assertions are parsed otherwise False.
|
:return: True if the assertions are parsed otherwise False.
|
||||||
"""
|
"""
|
||||||
if self.context == "AuthnQuery":
|
if self.context == "AuthnQuery":
|
||||||
@@ -884,12 +902,13 @@ class AuthnResponse(StatusResponse):
|
|||||||
else: # This is a saml2int limitation
|
else: # This is a saml2int limitation
|
||||||
try:
|
try:
|
||||||
assert len(self.response.assertion) == 1 or \
|
assert len(self.response.assertion) == 1 or \
|
||||||
len(self.response.encrypted_assertion) == 1
|
len(self.response.encrypted_assertion) == 1
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
raise Exception("No assertion part")
|
raise Exception("No assertion part")
|
||||||
|
|
||||||
has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion
|
has_encrypted_assertions = self.find_encrypt_data(self.response) #
|
||||||
#if not has_encrypted_assertions and self.response.assertion:
|
# self.response.encrypted_assertion
|
||||||
|
# if not has_encrypted_assertions and self.response.assertion:
|
||||||
# for tmp_assertion in self.response.assertion:
|
# for tmp_assertion in self.response.assertion:
|
||||||
# if tmp_assertion.advice:
|
# if tmp_assertion.advice:
|
||||||
# if tmp_assertion.advice.encrypted_assertion:
|
# if tmp_assertion.advice.encrypted_assertion:
|
||||||
@@ -912,15 +931,20 @@ class AuthnResponse(StatusResponse):
|
|||||||
decr_text_old = decr_text
|
decr_text_old = decr_text
|
||||||
decr_text = self.sec.decrypt_keys(decr_text, keys)
|
decr_text = self.sec.decrypt_keys(decr_text, keys)
|
||||||
resp = samlp.response_from_string(decr_text)
|
resp = samlp.response_from_string(decr_text)
|
||||||
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text)
|
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion,
|
||||||
|
decr_text)
|
||||||
decr_text_old = None
|
decr_text_old = None
|
||||||
while (self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions)) and \
|
while (self.find_encrypt_data(
|
||||||
|
resp) or self.find_encrypt_data_assertion_list(
|
||||||
|
_enc_assertions)) and \
|
||||||
decr_text_old != decr_text:
|
decr_text_old != decr_text:
|
||||||
decr_text_old = decr_text
|
decr_text_old = decr_text
|
||||||
decr_text = self.sec.decrypt_keys(decr_text, keys)
|
decr_text = self.sec.decrypt_keys(decr_text, keys)
|
||||||
resp = samlp.response_from_string(decr_text)
|
resp = samlp.response_from_string(decr_text)
|
||||||
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
|
_enc_assertions = self.decrypt_assertions(
|
||||||
#_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
|
resp.encrypted_assertion, decr_text, verified=True)
|
||||||
|
# _enc_assertions = self.decrypt_assertions(
|
||||||
|
# resp.encrypted_assertion, decr_text, verified=True)
|
||||||
all_assertions = _enc_assertions
|
all_assertions = _enc_assertions
|
||||||
if resp.assertion:
|
if resp.assertion:
|
||||||
all_assertions = all_assertions + resp.assertion
|
all_assertions = all_assertions + resp.assertion
|
||||||
@@ -928,9 +952,10 @@ class AuthnResponse(StatusResponse):
|
|||||||
for tmp_ass in all_assertions:
|
for tmp_ass in all_assertions:
|
||||||
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
|
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
|
||||||
|
|
||||||
advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion,
|
advice_res = self.decrypt_assertions(
|
||||||
decr_text,
|
tmp_ass.advice.encrypted_assertion,
|
||||||
tmp_ass.issuer)
|
decr_text,
|
||||||
|
tmp_ass.issuer)
|
||||||
if tmp_ass.advice.assertion:
|
if tmp_ass.advice.assertion:
|
||||||
tmp_ass.advice.assertion.extend(advice_res)
|
tmp_ass.advice.assertion.extend(advice_res)
|
||||||
else:
|
else:
|
||||||
@@ -1211,7 +1236,7 @@ class AssertionIDResponse(object):
|
|||||||
logger.exception("EXCEPTION: %s", excp)
|
logger.exception("EXCEPTION: %s", excp)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
#print("<", self.response)
|
# print("<", self.response)
|
||||||
|
|
||||||
return self._postamble()
|
return self._postamble()
|
||||||
|
|
||||||
@@ -1233,4 +1258,3 @@ class AssertionIDResponse(object):
|
|||||||
logger.debug("response: %s", self.response)
|
logger.debug("response: %s", self.response)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,14 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
from six.moves.urllib.parse import quote_plus
|
#from six.moves.urllib.parse import quote_plus
|
||||||
|
from future.backports.urllib.parse import quote_plus
|
||||||
from saml2.config import Config
|
from saml2.config import Config
|
||||||
from saml2.httpbase import HTTPBase
|
from saml2.mdstore import MetadataStore
|
||||||
from saml2.mdstore import MetadataStore, MetaDataMDX
|
from saml2.mdstore import MetaDataMDX
|
||||||
|
from saml2.mdstore import SAML_METADATA_CONTENT_TYPE
|
||||||
from saml2.mdstore import destinations
|
from saml2.mdstore import destinations
|
||||||
|
from saml2.mdstore import load_extensions
|
||||||
from saml2.mdstore import name
|
from saml2.mdstore import name
|
||||||
from saml2 import md
|
from saml2 import md
|
||||||
from saml2 import sigver
|
from saml2 import sigver
|
||||||
@@ -18,16 +21,13 @@ from saml2 import saml
|
|||||||
from saml2 import config
|
from saml2 import config
|
||||||
from saml2.attribute_converter import ac_factory
|
from saml2.attribute_converter import ac_factory
|
||||||
from saml2.attribute_converter import d_to_local_name
|
from saml2.attribute_converter import d_to_local_name
|
||||||
from saml2.extension import mdui
|
|
||||||
from saml2.extension import idpdisc
|
|
||||||
from saml2.extension import dri
|
|
||||||
from saml2.extension import mdattr
|
|
||||||
from saml2.extension import ui
|
|
||||||
from saml2.s_utils import UnknownPrincipal
|
from saml2.s_utils import UnknownPrincipal
|
||||||
from saml2 import xmldsig
|
from saml2 import xmldsig
|
||||||
from saml2 import xmlenc
|
from saml2 import xmlenc
|
||||||
from pathutils import full_path
|
from pathutils import full_path
|
||||||
|
|
||||||
|
import responses
|
||||||
|
|
||||||
sec_config = config.Config()
|
sec_config = config.Config()
|
||||||
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
|
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
|
||||||
|
|
||||||
@@ -88,16 +88,13 @@ TEST_METADATA_STRING = """
|
|||||||
|
|
||||||
ONTS = {
|
ONTS = {
|
||||||
saml.NAMESPACE: saml,
|
saml.NAMESPACE: saml,
|
||||||
mdui.NAMESPACE: mdui,
|
|
||||||
mdattr.NAMESPACE: mdattr,
|
|
||||||
dri.NAMESPACE: dri,
|
|
||||||
ui.NAMESPACE: ui,
|
|
||||||
idpdisc.NAMESPACE: idpdisc,
|
|
||||||
md.NAMESPACE: md,
|
md.NAMESPACE: md,
|
||||||
xmldsig.NAMESPACE: xmldsig,
|
xmldsig.NAMESPACE: xmldsig,
|
||||||
xmlenc.NAMESPACE: xmlenc
|
xmlenc.NAMESPACE: xmlenc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ONTS.update(load_extensions())
|
||||||
|
|
||||||
ATTRCONV = ac_factory(full_path("attributemaps"))
|
ATTRCONV = ac_factory(full_path("attributemaps"))
|
||||||
|
|
||||||
METADATACONF = {
|
METADATACONF = {
|
||||||
@@ -150,6 +147,10 @@ METADATACONF = {
|
|||||||
"class": "saml2.mdstore.InMemoryMetaData",
|
"class": "saml2.mdstore.InMemoryMetaData",
|
||||||
"metadata": [(TEST_METADATA_STRING,)]
|
"metadata": [(TEST_METADATA_STRING,)]
|
||||||
}],
|
}],
|
||||||
|
"12": [{
|
||||||
|
"class": "saml2.mdstore.MetaDataFile",
|
||||||
|
"metadata": [(full_path("uu.xml"),)],
|
||||||
|
}],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -303,6 +304,36 @@ def test_metadata_file():
|
|||||||
assert len(mds.keys()) == 560
|
assert len(mds.keys()) == 560
|
||||||
|
|
||||||
|
|
||||||
|
@responses.activate
|
||||||
|
def test_mdx_service():
|
||||||
|
entity_id = "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
|
||||||
|
|
||||||
|
url = "http://mdx.example.com/entities/{}".format(
|
||||||
|
quote_plus(MetaDataMDX.sha1_entity_transform(entity_id)))
|
||||||
|
responses.add(responses.GET, url, body=TEST_METADATA_STRING, status=200,
|
||||||
|
content_type=SAML_METADATA_CONTENT_TYPE)
|
||||||
|
|
||||||
|
mdx = MetaDataMDX("http://mdx.example.com")
|
||||||
|
sso_loc = mdx.service(entity_id, "idpsso_descriptor", "single_sign_on_service")
|
||||||
|
assert sso_loc[BINDING_HTTP_REDIRECT][0]["location"] == "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
|
||||||
|
certs = mdx.certs(entity_id, "idpsso")
|
||||||
|
assert len(certs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@responses.activate
|
||||||
|
def test_mdx_single_sign_on_service():
|
||||||
|
entity_id = "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
|
||||||
|
|
||||||
|
url = "http://mdx.example.com/entities/{}".format(
|
||||||
|
quote_plus(MetaDataMDX.sha1_entity_transform(entity_id)))
|
||||||
|
responses.add(responses.GET, url, body=TEST_METADATA_STRING, status=200,
|
||||||
|
content_type=SAML_METADATA_CONTENT_TYPE)
|
||||||
|
|
||||||
|
mdx = MetaDataMDX("http://mdx.example.com")
|
||||||
|
sso_loc = mdx.single_sign_on_service(entity_id, BINDING_HTTP_REDIRECT)
|
||||||
|
assert sso_loc[0]["location"] == "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
|
||||||
|
|
||||||
|
|
||||||
# pyff-test not available
|
# pyff-test not available
|
||||||
# def test_mdx_service():
|
# def test_mdx_service():
|
||||||
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
|
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
|
||||||
@@ -429,6 +460,12 @@ def test_get_certs_from_metadata_without_keydescriptor():
|
|||||||
|
|
||||||
assert len(certs) == 0
|
assert len(certs) == 0
|
||||||
|
|
||||||
|
def test_metadata_extension_algsupport():
|
||||||
|
mds = MetadataStore(list(ONTS.values()), ATTRCONV, None)
|
||||||
|
mds.imp(METADATACONF["12"])
|
||||||
|
mdf = mds.metadata[full_path("uu.xml")]
|
||||||
|
_txt = mdf.dumps()
|
||||||
|
assert mds
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_get_certs_from_metadata()
|
test_metadata_extension_algsupport()
|
||||||
|
|||||||
2
tests/test_requirements.txt
Normal file
2
tests/test_requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
pymongo==3.0.1
|
||||||
|
responses==0.5.0
|
||||||
Reference in New Issue
Block a user