Added support for one callback. Will be used by the saml2test tool.

This commit is contained in:
rohe
2015-12-18 08:03:18 +01:00
parent b28774d973
commit 75ae22eb09
3 changed files with 82 additions and 59 deletions

View File

@@ -89,7 +89,7 @@ class Base(Entity):
""" The basic pySAML2 service provider class """
def __init__(self, config=None, identity_cache=None, state_cache=None,
virtual_organization="", config_file=""):
virtual_organization="", config_file="", msg_cb=None):
"""
:param config: A saml2.config.Config instance
:param identity_cache: Where the class should store identity information
@@ -97,7 +97,8 @@ class Base(Entity):
:param virtual_organization: A specific virtual organization
"""
Entity.__init__(self, "sp", config, config_file, virtual_organization)
Entity.__init__(self, "sp", config, config_file, virtual_organization,
msg_cb=msg_cb)
self.users = Population(identity_cache)
self.lock = threading.Lock()
@@ -150,7 +151,8 @@ class Base(Entity):
raise IdpUnspecified("Too many IdPs to choose from: %s" % eids)
try:
srvs = self.metadata.single_sign_on_service(list(eids.keys())[0], binding)
srvs = self.metadata.single_sign_on_service(list(eids.keys())[0],
binding)
return destinations(srvs)[0]
except IndexError:
raise IdpUnspecified("No IdP to send to given the premises")
@@ -186,7 +188,7 @@ class Base(Entity):
ava = self.users.get_identity(name_id)[0]
return ava
#noinspection PyUnusedLocal
# noinspection PyUnusedLocal
@staticmethod
def is_session_valid(_session_id):
""" Place holder. Supposed to check if the session is still valid.
@@ -201,11 +203,12 @@ class Base(Entity):
return None
def create_authn_request(self, destination, vorg="", scoping=None,
binding=saml2.BINDING_HTTP_POST,
nameid_format=None,
service_url_binding=None, message_id=0,
consent=None, extensions=None, sign=None,
allow_create=False, sign_prepare=False, sign_alg=None, digest_alg=None, **kwargs):
binding=saml2.BINDING_HTTP_POST,
nameid_format=None,
service_url_binding=None, message_id=0,
consent=None, extensions=None, sign=None,
allow_create=False, sign_prepare=False, sign_alg=None,
digest_alg=None, **kwargs):
""" Creates an authentication request.
:param destination: Where the request should be sent.
@@ -244,7 +247,7 @@ class Base(Entity):
except KeyError:
try:
args["assertion_consumer_service_index"] = str(kwargs[
"assertion_consumer_service_index"])
"assertion_consumer_service_index"])
del kwargs["assertion_consumer_service_index"]
except KeyError:
if service_url_binding is None:
@@ -281,7 +284,6 @@ class Base(Entity):
raise ValueError("%s or wrong type expected %s" % (_item,
param))
try:
args["name_id_policy"] = kwargs["name_id_policy"]
del kwargs["name_id_policy"]
@@ -303,7 +305,6 @@ class Base(Entity):
# NameIDPolicy can only have one format specified
nameid_format = nameid_format[0]
name_id_policy = samlp.NameIDPolicy(allow_create=allow_create,
format=nameid_format)
@@ -334,7 +335,7 @@ class Base(Entity):
sign = self.authn_requests_signed
if (sign and self.sec.cert_handler.generate_cert()) or \
client_crt is not None:
client_crt is not None:
with self.lock:
self.sec.cert_handler.update_cert(True, client_crt)
if client_crt is not None:
@@ -342,16 +343,20 @@ class Base(Entity):
return self._message(AuthnRequest, destination, message_id,
consent, extensions, sign, sign_prepare,
protocol_binding=binding,
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg, **args)
scoping=scoping, nsprefix=nsprefix,
sign_alg=sign_alg, digest_alg=digest_alg,
**args)
return self._message(AuthnRequest, destination, message_id, consent,
extensions, sign, sign_prepare,
protocol_binding=binding,
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg, **args)
scoping=scoping, nsprefix=nsprefix,
sign_alg=sign_alg, digest_alg=digest_alg, **args)
def create_attribute_query(self, destination, name_id=None,
attribute=None, message_id=0, consent=None,
extensions=None, sign=False, sign_prepare=False, sign_alg=None, digest_alg=None,
**kwargs):
attribute=None, message_id=0, consent=None,
extensions=None, sign=False, sign_prepare=False, sign_alg=None,
digest_alg=None,
**kwargs):
""" Constructs an AttributeQuery
:param destination: To whom the query should be sent
@@ -407,15 +412,16 @@ class Base(Entity):
return self._message(AttributeQuery, destination, message_id, consent,
extensions, sign, sign_prepare, subject=subject,
attribute=attribute, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
attribute=attribute, nsprefix=nsprefix,
sign_alg=sign_alg, digest_alg=digest_alg)
# MUST use SOAP for
# AssertionIDRequest, SubjectQuery,
# AuthnQuery, AttributeQuery, or AuthzDecisionQuery
def create_authz_decision_query(self, destination, action,
evidence=None, resource=None, subject=None,
message_id=0, consent=None, extensions=None,
sign=None, sign_alg=None, digest_alg=None, **kwargs):
evidence=None, resource=None, subject=None,
message_id=0, consent=None, extensions=None,
sign=None, sign_alg=None, digest_alg=None, **kwargs):
""" Creates an authz decision query.
:param destination: The IdP endpoint
@@ -433,15 +439,16 @@ class Base(Entity):
return self._message(AuthzDecisionQuery, destination, message_id,
consent, extensions, sign, action=action,
evidence=evidence, resource=resource,
subject=subject, sign_alg=sign_alg, digest_alg=digest_alg, **kwargs)
subject=subject, sign_alg=sign_alg,
digest_alg=digest_alg, **kwargs)
def create_authz_decision_query_using_assertion(self, destination,
assertion, action=None,
resource=None,
subject=None, message_id=0,
consent=None,
extensions=None,
sign=False, nsprefix=None):
assertion, action=None,
resource=None,
subject=None, message_id=0,
consent=None,
extensions=None,
sign=False, nsprefix=None):
""" Makes an authz decision query based on a previously received
Assertion.
@@ -466,9 +473,9 @@ class Base(Entity):
_action = None
return self.create_authz_decision_query(
destination, _action, saml.Evidence(assertion=assertion),
resource, subject, message_id=message_id, consent=consent,
extensions=extensions, sign=sign, nsprefix=nsprefix)
destination, _action, saml.Evidence(assertion=assertion),
resource, subject, message_id=message_id, consent=consent,
extensions=extensions, sign=sign, nsprefix=nsprefix)
@staticmethod
def create_assertion_id_request(assertion_id_refs, **kwargs):
@@ -484,8 +491,9 @@ class Base(Entity):
return 0, assertion_id_refs[0]
def create_authn_query(self, subject, destination=None, authn_context=None,
session_index="", message_id=0, consent=None,
extensions=None, sign=False, nsprefix=None, sign_alg=None, digest_alg=None):
session_index="", message_id=0, consent=None,
extensions=None, sign=False, nsprefix=None, sign_alg=None,
digest_alg=None):
"""
:param subject: The subject its all about as a <Subject> instance
@@ -502,14 +510,15 @@ class Base(Entity):
extensions, sign, subject=subject,
session_index=session_index,
requested_authn_context=authn_context,
nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
nsprefix=nsprefix, sign_alg=sign_alg,
digest_alg=digest_alg)
def create_name_id_mapping_request(self, name_id_policy,
name_id=None, base_id=None,
encrypted_id=None, destination=None,
message_id=0, consent=None,
extensions=None, sign=False,
nsprefix=None, sign_alg=None, digest_alg=None):
name_id=None, base_id=None,
encrypted_id=None, destination=None,
message_id=0, consent=None,
extensions=None, sign=False,
nsprefix=None, sign_alg=None, digest_alg=None):
"""
:param name_id_policy:
@@ -531,22 +540,25 @@ class Base(Entity):
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy, name_id=name_id,
nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
nsprefix=nsprefix, sign_alg=sign_alg,
digest_alg=digest_alg)
elif base_id:
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy, base_id=base_id,
nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
nsprefix=nsprefix, sign_alg=sign_alg,
digest_alg=digest_alg)
else:
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy,
encrypted_id=encrypted_id, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
encrypted_id=encrypted_id, nsprefix=nsprefix,
sign_alg=sign_alg, digest_alg=digest_alg)
# ======== response handling ===========
def parse_authn_request_response(self, xmlstr, binding, outstanding=None,
outstanding_certs=None):
outstanding_certs=None):
""" Deal with an AuthnResponse
:param xmlstr: The reply as a xml string
@@ -554,8 +566,11 @@ class Base(Entity):
:param outstanding: A dictionary with session IDs as keys and
the original web request from the user before redirection
as values.
:param only_identity_in_encrypted_assertion: Must exist an assertion that is not encrypted that contains all
other information like subject and authentication statement.
:param only_identity_in_encrypted_assertion: Must exist an assertion
that is not encrypted that contains all
other information like
subject and
authentication statement.
:return: An response.AuthnResponse or None
"""
@@ -576,7 +591,7 @@ class Base(Entity):
"entity_id": self.config.entityid,
"attribute_converters": self.config.attribute_converters,
"allow_unknown_attributes":
self.config.allow_unknown_attributes,
self.config.allow_unknown_attributes,
}
try:
resp = self._parse_response(xmlstr, AuthnResponse,
@@ -594,12 +609,14 @@ class Base(Entity):
if resp is None:
return None
elif isinstance(resp, AuthnResponse):
if resp.assertion is not None and len(resp.response.encrypted_assertion) == 0:
if resp.assertion is not None and len(
resp.response.encrypted_assertion) == 0:
self.users.add_information_about_person(resp.session_info())
logger.info("--- ADDED person info ----")
pass
else:
logger.error("Response type not supported: %s", saml2.class_name(resp))
logger.error("Response type not supported: %s",
saml2.class_name(resp))
return resp
# ------------------------------------------------------------------------
@@ -607,7 +624,7 @@ class Base(Entity):
# AuthzDecisionQuery all get Response as response
def parse_authz_decision_query_response(self, response,
binding=BINDING_SOAP):
binding=BINDING_SOAP):
""" Verify that the response is OK
"""
kwargs = {"entity_id": self.config.entityid,
@@ -658,7 +675,7 @@ class Base(Entity):
# ------------------- ECP ------------------------------------------------
def create_ecp_authn_request(self, entityid=None, relay_state="",
sign=False, **kwargs):
sign=False, **kwargs):
""" Makes an authentication request.
:param entityid: The entity ID of the IdP to send the request to
@@ -710,7 +727,7 @@ class Base(Entity):
_, location = self.pick_binding("single_sign_on_service",
[_binding], entity_id=entityid)
req_id, authn_req = self.create_authn_request(
location, service_url_binding=BINDING_PAOS, **kwargs)
location, service_url_binding=BINDING_PAOS, **kwargs)
# ----------------------------------------
# The SOAP envelope
@@ -730,8 +747,8 @@ class Base(Entity):
_relay_state = None
for item in rdict["header"]:
if item.c_tag == "RelayState" and\
item.c_namespace == ecp.NAMESPACE:
if item.c_tag == "RelayState" and \
item.c_namespace == ecp.NAMESPACE:
_relay_state = item
response = self.parse_authn_request_response(rdict["body"],
@@ -805,7 +822,7 @@ class Base(Entity):
@staticmethod
def parse_discovery_service_response(url="", query="",
returnIDParam="entityID"):
returnIDParam="entityID"):
"""
Deal with the response url from a Discovery Service

View File

@@ -124,7 +124,7 @@ def create_artifact(entity_id, message_handle, endpoint_index=0):
class Entity(HTTPBase):
def __init__(self, entity_type, config=None, config_file="",
virtual_organization=""):
virtual_organization="", msg_cb=None):
self.entity_type = entity_type
self.users = None
@@ -177,6 +177,8 @@ class Entity(HTTPBase):
else:
self.sourceid = {}
self.msg_cb = msg_cb
def _issuer(self, entityid=None):
""" Return an Issuer instance """
if entityid:
@@ -465,7 +467,6 @@ class Entity(HTTPBase):
kwargs[key] = val
req = request_cls(**kwargs)
reqid = req.id
if destination:
req.destination = destination
@@ -479,6 +480,11 @@ class Entity(HTTPBase):
if nsprefix:
req.register_prefix(nsprefix)
if self.msg_cb:
req = self.msg_cb(req)
reqid = req.id
if sign:
return reqid, self.sign(req, sign_prepare=sign_prepare,
sign_alg=sign_alg, digest_alg=digest_alg)

View File

@@ -74,8 +74,8 @@ class Server(Entity):
""" A class that does things that IdPs or AAs do """
def __init__(self, config_file="", config=None, cache=None, stype="idp",
symkey=""):
Entity.__init__(self, stype, config, config_file)
symkey="", msg_cb=None):
Entity.__init__(self, stype, config, config_file, msg_cb=msg_cb)
self.eptid = None
self.init_config(stype)
self.cache = cache