More IdP discovery support.

This commit is contained in:
Roland Hedberg
2013-01-28 11:25:13 +01:00
parent a1ebb808c8
commit d3ec4b1d55
4 changed files with 62 additions and 10 deletions

View File

@@ -49,7 +49,7 @@ class DiscoveryServer(Entity):
return dsr
# -----------------------------------------------------------------------------
# -------------------------------------------------------------------------
def create_discovery_service_response(self, url, IDparam="entityID",
entity_id=None):
@@ -65,4 +65,10 @@ class DiscoveryServer(Entity):
return url
def verify_sp_in_metadata(self, entity_id):
if self.metadata:
endp = self.metadata.discovery_response(entity_id)
if endp:
return True
return False

View File

@@ -8,6 +8,7 @@ import saml2
from saml2 import md
NAMESPACE = 'urn:oasis:names:tc:SAML:profiles:SSO:idp-discovery-protocol'
BINDING_DISCO = "urn:oasis:names:tc:SAML:profiles:SSO:idp-discovery-protocol"
class DiscoveryResponse(md.IndexedEndpointType_):
"""The urn:oasis:names:tc:SAML:profiles:SSO:idp-discovery-protocol:DiscoveryResponse element """

View File

@@ -4,6 +4,8 @@ import sys
import json
from hashlib import sha1
from saml2.extension.idpdisc import BINDING_DISCO
from saml2.extension.idpdisc import DiscoveryResponse
from saml2.mdie import to_dict
@@ -39,6 +41,7 @@ REQ2SRV = {
# SP
"assertion_response": "assertion_consumer_service",
"attribute_response": "attribute_consuming_service",
"discovery_service_request": "discovery_response"
}
def destinations(srvs):
@@ -198,6 +201,25 @@ class MetaData(object):
res[srv["binding"]] = [srv]
return res
def _ext_service(self, entity_id, typ, service, binding):
try:
srvs = self.entity[entity_id][typ]
except KeyError:
return None
if not srvs:
return srvs
res = []
for srv in srvs:
if "extensions" in srv:
for elem in srv["extensions"]["extension_elements"]:
if elem["__class__"] == service:
if elem["binding"] == binding:
res.append(elem)
return res
def any(self, typ, service, binding=None):
"""
Return any entity that matches the specification
@@ -363,6 +385,13 @@ class MetadataStore(object):
return srvs
return []
def _ext_service(self, entity_id, typ, service, binding=None):
for key, md in self.metadata.items():
srvs = md._ext_service(entity_id, typ, service, binding)
if srvs:
return srvs
return []
def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"):
# IDP
@@ -433,20 +462,28 @@ class MetadataStore(object):
return self._service(entity_id, "%s_descriptor" % typ,
"artifact_resolution_service", binding)
def assertion_consumer_service(self, entity_id, binding=None, typ="spsso"):
def assertion_consumer_service(self, entity_id, binding=None, _="spsso"):
# SP
if binding is None:
binding = BINDING_HTTP_POST
return self._service(entity_id, "spsso_descriptor",
"assertion_consumer_service", binding)
def attribute_consuming_service(self, entity_id, binding=None, typ="spsso"):
def attribute_consuming_service(self, entity_id, binding=None, _="spsso"):
# SP
if binding is None:
binding = BINDING_HTTP_REDIRECT
return self._service(entity_id, "spsso_descriptor",
"attribute_consuming_service", binding)
def discovery_response(self, entity_id, binding=None, _="spsso"):
if binding is None:
binding = BINDING_DISCO
return self._ext_service(entity_id, "spsso_descriptor",
"%s&%s" % (DiscoveryResponse.c_namespace,
DiscoveryResponse.c_tag),
binding)
def attribute_requirement(self, entity_id, index=0):
for md in self.metadata.values():
if entity_id in md.entity:

View File

@@ -263,6 +263,12 @@ ENDPOINTS = {
}
}
ENDPOINT_EXT = {
"sp": {
"discovery_response": (idpdisc.DiscoveryResponse, True)
}
}
DEFAULT_BINDING = {
"assertion_consumer_service": BINDING_HTTP_POST,
"single_sign_on_service": BINDING_HTTP_REDIRECT,
@@ -317,10 +323,17 @@ def do_spsso_descriptor(conf, cert=None):
endps = conf.getattr("endpoints", "sp")
if endps:
for (endpoint, instlist) in do_endpoints(endps,
ENDPOINTS["sp"]).items():
for (endpoint, instlist) in do_endpoints(endps, ENDPOINTS["sp"]).items():
setattr(spsso, endpoint, instlist)
ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
if ext:
if spsso.extensions is None:
spsso.extensions = md.Extensions()
for vals in ext.values():
for val in vals:
spsso.extensions.add_extension_element(val)
if cert:
spsso.key_descriptor = do_key_descriptor(cert)
@@ -373,11 +386,6 @@ def do_spsso_descriptor(conf, cert=None):
# except KeyError:
# pass
dresp = conf.getattr("discovery_response", "sp")
if dresp:
if spsso.extensions is None:
spsso.extensions = md.Extensions()
spsso.extensions.add_extension_element(do_idpdisc(dresp))
return spsso