Fixed a problem in parsing metadata extensions.

This commit is contained in:
Roland Hedberg
2015-11-18 10:38:39 +01:00
parent 0218b0b064
commit 2ce425c84c
6 changed files with 167 additions and 83 deletions

View File

@@ -979,7 +979,7 @@ def extension_elements_to_elements(extension_elements, schemas):
if isinstance(schemas, list): if isinstance(schemas, list):
pass pass
elif isinstance(schemas, dict): elif isinstance(schemas, dict):
schemas = schemas.values() schemas = list(schemas.values())
else: else:
return res return res

View File

@@ -1,19 +1,20 @@
from __future__ import print_function from __future__ import print_function
import hashlib
import logging import logging
import os import os
import sys import sys
import json import json
import six
import requests
import six
from hashlib import sha1 from hashlib import sha1
from os.path import isfile, join from os.path import isfile, join
from saml2.httpbase import HTTPBase from saml2.httpbase import HTTPBase
from saml2.extension.idpdisc import BINDING_DISCO from saml2.extension.idpdisc import BINDING_DISCO
from saml2.extension.idpdisc import DiscoveryResponse from saml2.extension.idpdisc import DiscoveryResponse
from saml2.md import EntitiesDescriptor from saml2.md import EntitiesDescriptor
from saml2.mdie import to_dict from saml2.mdie import to_dict
from saml2 import md from saml2 import md
from saml2 import samlp from saml2 import samlp
from saml2 import SAMLError from saml2 import SAMLError
@@ -67,6 +68,20 @@ ENTITY_CATEGORY_SUPPORT = "http://macedir.org/entity-category-support"
# --------------------------------------------------- # ---------------------------------------------------
def load_extensions():
from saml2 import extension
import pkgutil
package = extension
prefix = package.__name__ + "."
ext_map = {}
for importer, modname, ispkg in pkgutil.iter_modules(package.__path__,
prefix):
module = __import__(modname, fromlist="dummy")
ext_map[module.NAMESPACE] = module
return ext_map
def destinations(srvs): def destinations(srvs):
return [s["location"] for s in srvs] return [s["location"] for s in srvs]
@@ -564,8 +579,8 @@ class InMemoryMetaData(MetaData):
return True return True
node_name = self.node_name \ node_name = self.node_name \
or "%s:%s" % (md.EntitiesDescriptor.c_namespace, or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
md.EntitiesDescriptor.c_tag) md.EntitiesDescriptor.c_tag)
if self.security.verify_signature( if self.security.verify_signature(
txt, node_name=node_name, cert_file=self.cert): txt, node_name=node_name, cert_file=self.cert):
@@ -705,27 +720,31 @@ class MetaDataMDX(InMemoryMetaData):
""" Uses the md protocol to fetch entity information """ Uses the md protocol to fetch entity information
""" """
def __init__(self, entity_transform, onts, attrc, url, security, cert, @staticmethod
http, **kwargs): def sha1_entity_transform(entity_id):
return "{{sha1}}{}".format(
hashlib.sha1(entity_id.encode("utf-8")).hexdigest())
def __init__(self, url, entity_transform=None):
""" """
:params entity_transform: function transforming (e.g. base64 or sha1 :params url: mdx service url
:params entity_transform: function transforming (e.g. base64,
sha1 hash or URL quote
hash) the entity id. It is applied to the entity id before it is hash) the entity id. It is applied to the entity id before it is
concatenated with the request URL sent to the MDX server. concatenated with the request URL sent to the MDX server. Defaults to
:params onts: sha1 transformation.
:params attrc:
:params url:
:params security: SecurityContext()
:params cert:
:params http:
""" """
super(MetaDataMDX, self).__init__(onts, attrc, **kwargs) super(MetaDataMDX, self).__init__(None, None)
self.url = url self.url = url
self.security = security
self.cert = cert if entity_transform:
self.http = http self.entity_transform = entity_transform
self.entity_transform = entity_transform else:
self.entity_transform = MetaDataMDX.sha1_entity_transform
def load(self): def load(self):
# Do nothing
pass pass
def __getitem__(self, item): def __getitem__(self, item):
@@ -733,13 +752,9 @@ class MetaDataMDX(InMemoryMetaData):
return self.entity[item] return self.entity[item]
except KeyError: except KeyError:
mdx_url = "%s/entities/%s" % (self.url, self.entity_transform(item)) mdx_url = "%s/entities/%s" % (self.url, self.entity_transform(item))
response = self.http.send( response = requests.get(mdx_url, headers={
mdx_url, headers={'Accept': SAML_METADATA_CONTENT_TYPE}) 'Accept': SAML_METADATA_CONTENT_TYPE})
if response.status_code == 200: if response.status_code == 200:
node_name = self.node_name \
or "%s:%s" % (md.EntitiesDescriptor.c_namespace,
md.EntitiesDescriptor.c_tag)
_txt = response.text.encode("utf-8") _txt = response.text.encode("utf-8")
if self.parse_and_check_signature(_txt): if self.parse_and_check_signature(_txt):
@@ -748,6 +763,12 @@ class MetaDataMDX(InMemoryMetaData):
logger.info("Response status: %s", response.status_code) logger.info("Response status: %s", response.status_code)
raise KeyError raise KeyError
def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"):
if binding is None:
binding = BINDING_HTTP_REDIRECT
return self.service(entity_id, "idpsso_descriptor",
"single_sign_on_service", binding)
class MetadataStore(MetaData): class MetadataStore(MetaData):
def __init__(self, onts, attrc, config, ca_certs=None, def __init__(self, onts, attrc, config, ca_certs=None,

View File

@@ -58,6 +58,7 @@ from saml2.validate import NotValid
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -160,9 +161,11 @@ class StatusUnknownPrincipal(StatusError):
class StatusUnsupportedBinding(StatusError): class StatusUnsupportedBinding(StatusError):
pass pass
class StatusResponder(StatusError): class StatusResponder(StatusError):
pass pass
STATUSCODE2EXCEPTION = { STATUSCODE2EXCEPTION = {
STATUS_VERSION_MISMATCH: StatusVersionMismatch, STATUS_VERSION_MISMATCH: StatusVersionMismatch,
STATUS_AUTHN_FAILED: StatusAuthnFailed, STATUS_AUTHN_FAILED: StatusAuthnFailed,
@@ -186,6 +189,8 @@ STATUSCODE2EXCEPTION = {
STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding, STATUS_UNSUPPORTED_BINDING: StatusUnsupportedBinding,
STATUS_RESPONDER: StatusResponder, STATUS_RESPONDER: StatusResponder,
} }
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -206,7 +211,8 @@ def for_me(conditions, myself):
if audience.text.strip() == myself: if audience.text.strip() == myself:
return True return True
else: else:
#print("Not for me: %s != %s" % (audience.text.strip(), myself)) # print("Not for me: %s != %s" % (audience.text.strip(),
# myself))
pass pass
return False return False
@@ -336,7 +342,7 @@ class StatusResponse(object):
logger.exception("EXCEPTION: %s", excp) logger.exception("EXCEPTION: %s", excp)
raise raise
#print("<", self.response) # print("<", self.response)
return self._postamble() return self._postamble()
@@ -377,7 +383,7 @@ class StatusResponse(object):
if self.request_id and self.in_response_to and \ if self.request_id and self.in_response_to and \
self.in_response_to != self.request_id: self.in_response_to != self.request_id:
logger.error("Not the id I expected: %s != %s", logger.error("Not the id I expected: %s != %s",
self.in_response_to, self.request_id) self.in_response_to, self.request_id)
return None return None
try: try:
@@ -391,9 +397,9 @@ class StatusResponse(object):
if self.asynchop: if self.asynchop:
if self.response.destination and \ if self.response.destination and \
self.response.destination not in self.return_addrs: self.response.destination not in self.return_addrs:
logger.error("%s not in %s", self.response.destination, logger.error("%s not in %s", self.response.destination,
self.return_addrs) self.return_addrs)
return None return None
assert self.issue_instant_ok() assert self.issue_instant_ok()
@@ -436,7 +442,7 @@ class NameIDMappingResponse(StatusResponse):
request_id=0, asynchop=True): request_id=0, asynchop=True):
StatusResponse.__init__(self, sec_context, return_addrs, timeslack, StatusResponse.__init__(self, sec_context, return_addrs, timeslack,
request_id, asynchop) request_id, asynchop)
self.signature_check = self.sec\ self.signature_check = self.sec \
.correctly_signed_name_id_mapping_response .correctly_signed_name_id_mapping_response
@@ -506,7 +512,7 @@ class AuthnResponse(StatusResponse):
if self.asynchop: if self.asynchop:
if self.in_response_to in self.outstanding_queries: if self.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[self.in_response_to] self.came_from = self.outstanding_queries[self.in_response_to]
#del self.outstanding_queries[self.in_response_to] # del self.outstanding_queries[self.in_response_to]
try: try:
if not self.check_subject_confirmation_in_response_to( if not self.check_subject_confirmation_in_response_to(
self.in_response_to): self.in_response_to):
@@ -632,12 +638,12 @@ class AuthnResponse(StatusResponse):
def read_attribute_statement(self, attr_statem): def read_attribute_statement(self, attr_statem):
logger.debug("Attribute Statement: %s", attr_statem) logger.debug("Attribute Statement: %s", attr_statem)
for aconv in self.attribute_converters: # for aconv in self.attribute_converters:
logger.debug("Converts name format: %s", aconv.name_format) # logger.debug("Converts name format: %s", aconv.name_format)
self.decrypt_attributes(attr_statem) self.decrypt_attributes(attr_statem)
return to_local(self.attribute_converters, attr_statem, return to_local(self.attribute_converters, attr_statem,
self.allow_unknown_attributes) self.allow_unknown_attributes)
def get_identity(self): def get_identity(self):
""" The assertion can contain zero or one attributeStatements """ The assertion can contain zero or one attributeStatements
@@ -650,7 +656,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in _assertion.advice.assertion: for tmp_assertion in _assertion.advice.assertion:
if tmp_assertion.attribute_statement: if tmp_assertion.attribute_statement:
assert len(tmp_assertion.attribute_statement) == 1 assert len(tmp_assertion.attribute_statement) == 1
ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0])) ava.update(self.read_attribute_statement(
tmp_assertion.attribute_statement[0]))
if _assertion.attribute_statement: if _assertion.attribute_statement:
assert len(_assertion.attribute_statement) == 1 assert len(_assertion.attribute_statement) == 1
_attr_statem = _assertion.attribute_statement[0] _attr_statem = _assertion.attribute_statement[0]
@@ -681,7 +688,7 @@ class AuthnResponse(StatusResponse):
if data.in_response_to in self.outstanding_queries: if data.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[ self.came_from = self.outstanding_queries[
data.in_response_to] data.in_response_to]
#del self.outstanding_queries[data.in_response_to] # del self.outstanding_queries[data.in_response_to]
elif self.allow_unsolicited: elif self.allow_unsolicited:
pass pass
else: else:
@@ -690,7 +697,7 @@ class AuthnResponse(StatusResponse):
# recognize # recognize
logger.debug("in response to: '%s'", data.in_response_to) logger.debug("in response to: '%s'", data.in_response_to)
logger.info("outstanding queries: %s", logger.info("outstanding queries: %s",
self.outstanding_queries.keys()) self.outstanding_queries.keys())
raise Exception( raise Exception(
"Combination of session id and requestURI I don't " "Combination of session id and requestURI I don't "
"recall") "recall")
@@ -768,7 +775,8 @@ class AuthnResponse(StatusResponse):
logger.debug("signed") logger.debug("signed")
if not verified and self.do_not_verify is False: if not verified and self.do_not_verify is False:
try: try:
self.sec.check_signature(assertion, class_name(assertion),self.xmlstr) self.sec.check_signature(assertion, class_name(assertion),
self.xmlstr)
except Exception as exc: except Exception as exc:
logger.error("correctly_signed_response: %s", exc) logger.error("correctly_signed_response: %s", exc)
raise raise
@@ -778,10 +786,10 @@ class AuthnResponse(StatusResponse):
logger.debug("assertion keys: %s", assertion.keyswv()) logger.debug("assertion keys: %s", assertion.keyswv())
logger.debug("outstanding_queries: %s", self.outstanding_queries) logger.debug("outstanding_queries: %s", self.outstanding_queries)
#if self.context == "AuthnReq" or self.context == "AttrQuery": # if self.context == "AuthnReq" or self.context == "AttrQuery":
if self.context == "AuthnReq": if self.context == "AuthnReq":
self.authn_statement_ok() self.authn_statement_ok()
# elif self.context == "AttrQuery": # elif self.context == "AttrQuery":
# self.authn_statement_ok(True) # self.authn_statement_ok(True)
if not self.condition_ok(): if not self.condition_ok():
@@ -789,7 +797,7 @@ class AuthnResponse(StatusResponse):
logger.debug("--- Getting Identity ---") logger.debug("--- Getting Identity ---")
#if self.context == "AuthnReq" or self.context == "AttrQuery": # if self.context == "AuthnReq" or self.context == "AttrQuery":
# self.ava = self.get_identity() # self.ava = self.get_identity()
# logger.debug("--- AVA: %s", self.ava) # logger.debug("--- AVA: %s", self.ava)
@@ -805,13 +813,17 @@ class AuthnResponse(StatusResponse):
logger.exception("get subject") logger.exception("get subject")
raise raise
def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False): def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None,
""" Moves the decrypted assertion from the encrypted assertion to a list. verified=False):
""" Moves the decrypted assertion from the encrypted assertion to a
list.
:param encrypted_assertions: A list of encrypted assertions. :param encrypted_assertions: A list of encrypted assertions.
:param decr_txt: The string representation containing the decrypted data. Used when verifying signatures. :param decr_txt: The string representation containing the decrypted
data. Used when verifying signatures.
:param issuer: The issuer of the response. :param issuer: The issuer of the response.
:param verified: If True do not verify signatures, otherwise verify the signature if it exists. :param verified: If True do not verify signatures, otherwise verify
the signature if it exists.
:return: A list of decrypted assertions. :return: A list of decrypted assertions.
""" """
res = [] res = []
@@ -824,7 +836,8 @@ class AuthnResponse(StatusResponse):
if not self.sec.check_signature( if not self.sec.check_signature(
assertion, origdoc=decr_txt, assertion, origdoc=decr_txt,
node_name=class_name(assertion), issuer=issuer): node_name=class_name(assertion), issuer=issuer):
logger.error("Failed to verify signature on '%s'", assertion) logger.error("Failed to verify signature on '%s'",
assertion)
raise SignatureError() raise SignatureError()
res.append(assertion) res.append(assertion)
return res return res
@@ -836,11 +849,12 @@ class AuthnResponse(StatusResponse):
:return: True encrypted data exists otherwise false. :return: True encrypted data exists otherwise false.
""" """
for _assertion in enc_assertions: for _assertion in enc_assertions:
if _assertion.encrypted_data is not None: if _assertion.encrypted_data is not None:
return True return True
def find_encrypt_data_assertion_list(self, _assertions): def find_encrypt_data_assertion_list(self, _assertions):
""" Verifies if a list of assertions contains encrypted data in the advice element. """ Verifies if a list of assertions contains encrypted data in the
advice element.
:param _assertions: A list of assertions. :param _assertions: A list of assertions.
:return: True encrypted data exists otherwise false. :return: True encrypted data exists otherwise false.
@@ -848,12 +862,14 @@ class AuthnResponse(StatusResponse):
for _assertion in _assertions: for _assertion in _assertions:
if _assertion.advice: if _assertion.advice:
if _assertion.advice.encrypted_assertion: if _assertion.advice.encrypted_assertion:
res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion) res = self.find_encrypt_data_assertion(
_assertion.advice.encrypted_assertion)
if res: if res:
return True return True
def find_encrypt_data(self, resp): def find_encrypt_data(self, resp):
""" Verifies if a saml response contains encrypted assertions with encrypted data. """ Verifies if a saml response contains encrypted assertions with
encrypted data.
:param resp: A saml response. :param resp: A saml response.
:return: True encrypted data exists otherwise false. :return: True encrypted data exists otherwise false.
@@ -867,7 +883,8 @@ class AuthnResponse(StatusResponse):
for tmp_assertion in resp.assertion: for tmp_assertion in resp.assertion:
if tmp_assertion.advice: if tmp_assertion.advice:
if tmp_assertion.advice.encrypted_assertion: if tmp_assertion.advice.encrypted_assertion:
res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion) res = self.find_encrypt_data_assertion(
tmp_assertion.advice.encrypted_assertion)
if res: if res:
return True return True
return False return False
@@ -875,7 +892,8 @@ class AuthnResponse(StatusResponse):
def parse_assertion(self, keys=None): def parse_assertion(self, keys=None):
""" Parse the assertions for a saml response. """ Parse the assertions for a saml response.
:param keys: A string representing a RSA key or a list of strings containing RSA keys. :param keys: A string representing a RSA key or a list of strings
containing RSA keys.
:return: True if the assertions are parsed otherwise False. :return: True if the assertions are parsed otherwise False.
""" """
if self.context == "AuthnQuery": if self.context == "AuthnQuery":
@@ -884,12 +902,13 @@ class AuthnResponse(StatusResponse):
else: # This is a saml2int limitation else: # This is a saml2int limitation
try: try:
assert len(self.response.assertion) == 1 or \ assert len(self.response.assertion) == 1 or \
len(self.response.encrypted_assertion) == 1 len(self.response.encrypted_assertion) == 1
except AssertionError: except AssertionError:
raise Exception("No assertion part") raise Exception("No assertion part")
has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion has_encrypted_assertions = self.find_encrypt_data(self.response) #
#if not has_encrypted_assertions and self.response.assertion: # self.response.encrypted_assertion
# if not has_encrypted_assertions and self.response.assertion:
# for tmp_assertion in self.response.assertion: # for tmp_assertion in self.response.assertion:
# if tmp_assertion.advice: # if tmp_assertion.advice:
# if tmp_assertion.advice.encrypted_assertion: # if tmp_assertion.advice.encrypted_assertion:
@@ -912,15 +931,20 @@ class AuthnResponse(StatusResponse):
decr_text_old = decr_text decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys) decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text) resp = samlp.response_from_string(decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text) _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion,
decr_text)
decr_text_old = None decr_text_old = None
while (self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions)) and \ while (self.find_encrypt_data(
resp) or self.find_encrypt_data_assertion_list(
_enc_assertions)) and \
decr_text_old != decr_text: decr_text_old != decr_text:
decr_text_old = decr_text decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys) decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text) resp = samlp.response_from_string(decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) _enc_assertions = self.decrypt_assertions(
#_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) resp.encrypted_assertion, decr_text, verified=True)
# _enc_assertions = self.decrypt_assertions(
# resp.encrypted_assertion, decr_text, verified=True)
all_assertions = _enc_assertions all_assertions = _enc_assertions
if resp.assertion: if resp.assertion:
all_assertions = all_assertions + resp.assertion all_assertions = all_assertions + resp.assertion
@@ -928,9 +952,10 @@ class AuthnResponse(StatusResponse):
for tmp_ass in all_assertions: for tmp_ass in all_assertions:
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion: if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion, advice_res = self.decrypt_assertions(
decr_text, tmp_ass.advice.encrypted_assertion,
tmp_ass.issuer) decr_text,
tmp_ass.issuer)
if tmp_ass.advice.assertion: if tmp_ass.advice.assertion:
tmp_ass.advice.assertion.extend(advice_res) tmp_ass.advice.assertion.extend(advice_res)
else: else:
@@ -1211,7 +1236,7 @@ class AssertionIDResponse(object):
logger.exception("EXCEPTION: %s", excp) logger.exception("EXCEPTION: %s", excp)
raise raise
#print("<", self.response) # print("<", self.response)
return self._postamble() return self._postamble()
@@ -1233,4 +1258,3 @@ class AssertionIDResponse(object):
logger.debug("response: %s", self.response) logger.debug("response: %s", self.response)
return self return self

View File

@@ -2,11 +2,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime import datetime
import re import re
from six.moves.urllib.parse import quote_plus #from six.moves.urllib.parse import quote_plus
from future.backports.urllib.parse import quote_plus
from saml2.config import Config from saml2.config import Config
from saml2.httpbase import HTTPBase from saml2.mdstore import MetadataStore
from saml2.mdstore import MetadataStore, MetaDataMDX from saml2.mdstore import MetaDataMDX
from saml2.mdstore import SAML_METADATA_CONTENT_TYPE
from saml2.mdstore import destinations from saml2.mdstore import destinations
from saml2.mdstore import load_extensions
from saml2.mdstore import name from saml2.mdstore import name
from saml2 import md from saml2 import md
from saml2 import sigver from saml2 import sigver
@@ -18,16 +21,13 @@ from saml2 import saml
from saml2 import config from saml2 import config
from saml2.attribute_converter import ac_factory from saml2.attribute_converter import ac_factory
from saml2.attribute_converter import d_to_local_name from saml2.attribute_converter import d_to_local_name
from saml2.extension import mdui
from saml2.extension import idpdisc
from saml2.extension import dri
from saml2.extension import mdattr
from saml2.extension import ui
from saml2.s_utils import UnknownPrincipal from saml2.s_utils import UnknownPrincipal
from saml2 import xmldsig from saml2 import xmldsig
from saml2 import xmlenc from saml2 import xmlenc
from pathutils import full_path from pathutils import full_path
import responses
sec_config = config.Config() sec_config = config.Config()
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"]) # sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
@@ -88,16 +88,13 @@ TEST_METADATA_STRING = """
ONTS = { ONTS = {
saml.NAMESPACE: saml, saml.NAMESPACE: saml,
mdui.NAMESPACE: mdui,
mdattr.NAMESPACE: mdattr,
dri.NAMESPACE: dri,
ui.NAMESPACE: ui,
idpdisc.NAMESPACE: idpdisc,
md.NAMESPACE: md, md.NAMESPACE: md,
xmldsig.NAMESPACE: xmldsig, xmldsig.NAMESPACE: xmldsig,
xmlenc.NAMESPACE: xmlenc xmlenc.NAMESPACE: xmlenc
} }
ONTS.update(load_extensions())
ATTRCONV = ac_factory(full_path("attributemaps")) ATTRCONV = ac_factory(full_path("attributemaps"))
METADATACONF = { METADATACONF = {
@@ -150,6 +147,10 @@ METADATACONF = {
"class": "saml2.mdstore.InMemoryMetaData", "class": "saml2.mdstore.InMemoryMetaData",
"metadata": [(TEST_METADATA_STRING,)] "metadata": [(TEST_METADATA_STRING,)]
}], }],
"12": [{
"class": "saml2.mdstore.MetaDataFile",
"metadata": [(full_path("uu.xml"),)],
}],
} }
@@ -303,6 +304,36 @@ def test_metadata_file():
assert len(mds.keys()) == 560 assert len(mds.keys()) == 560
@responses.activate
def test_mdx_service():
entity_id = "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
url = "http://mdx.example.com/entities/{}".format(
quote_plus(MetaDataMDX.sha1_entity_transform(entity_id)))
responses.add(responses.GET, url, body=TEST_METADATA_STRING, status=200,
content_type=SAML_METADATA_CONTENT_TYPE)
mdx = MetaDataMDX("http://mdx.example.com")
sso_loc = mdx.service(entity_id, "idpsso_descriptor", "single_sign_on_service")
assert sso_loc[BINDING_HTTP_REDIRECT][0]["location"] == "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
certs = mdx.certs(entity_id, "idpsso")
assert len(certs) == 1
@responses.activate
def test_mdx_single_sign_on_service():
entity_id = "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
url = "http://mdx.example.com/entities/{}".format(
quote_plus(MetaDataMDX.sha1_entity_transform(entity_id)))
responses.add(responses.GET, url, body=TEST_METADATA_STRING, status=200,
content_type=SAML_METADATA_CONTENT_TYPE)
mdx = MetaDataMDX("http://mdx.example.com")
sso_loc = mdx.single_sign_on_service(entity_id, BINDING_HTTP_REDIRECT)
assert sso_loc[0]["location"] == "http://xenosmilus.umdc.umu.se/simplesaml/saml2/idp/metadata.php"
# pyff-test not available # pyff-test not available
# def test_mdx_service(): # def test_mdx_service():
# sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"]) # sec_config.xmlsec_binary = sigver.get_xmlsec_binary(["/opt/local/bin"])
@@ -429,6 +460,12 @@ def test_get_certs_from_metadata_without_keydescriptor():
assert len(certs) == 0 assert len(certs) == 0
def test_metadata_extension_algsupport():
mds = MetadataStore(list(ONTS.values()), ATTRCONV, None)
mds.imp(METADATACONF["12"])
mdf = mds.metadata[full_path("uu.xml")]
_txt = mdf.dumps()
assert mds
if __name__ == "__main__": if __name__ == "__main__":
test_get_certs_from_metadata() test_metadata_extension_algsupport()

View File

@@ -0,0 +1,2 @@
pymongo==3.0.1
responses==0.5.0

View File

@@ -3,5 +3,5 @@ envlist = py27,py34
[testenv] [testenv]
deps = pytest deps = pytest
pymongo==3.0.1 -rtests/test_requirements.txt
commands = py.test tests/ commands = py.test tests/