This commit is contained in:
Roland Hedberg
2014-04-08 15:45:11 +02:00
parent 0f32fccf56
commit d10e9e637e

View File

@@ -8,13 +8,15 @@ from saml2.soap import parse_soap_enveloped_saml_artifact_resolve
from saml2.soap import class_instances_from_soap_enveloped_saml_thingies from saml2.soap import class_instances_from_soap_enveloped_saml_thingies
from saml2.soap import open_soap_envelope from saml2.soap import open_soap_envelope
from saml2 import samlp, SamlBase, SAMLError from saml2 import samlp
from saml2 import SamlBase
from saml2 import SAMLError
from saml2 import saml from saml2 import saml
from saml2 import response from saml2 import response as saml_response
from saml2 import BINDING_URI from saml2 import BINDING_URI
from saml2 import BINDING_HTTP_ARTIFACT from saml2 import BINDING_HTTP_ARTIFACT
from saml2 import BINDING_PAOS from saml2 import BINDING_PAOS
from saml2 import request from saml2 import request as saml_request
from saml2 import soap from saml2 import soap
from saml2 import element_to_extension_element from saml2 import element_to_extension_element
from saml2 import extension_elements_to_elements from saml2 import extension_elements_to_elements
@@ -296,7 +298,16 @@ class Entity(HTTPBase):
return info return info
def unravel(self, txt, binding, msgtype="response"): @staticmethod
def unravel(txt, binding, msgtype="response"):
"""
Will unpack the received text. Depending on the context the original
response may have been transformed before transmission.
:param txt:
:param binding:
:param msgtype:
:return:
"""
#logger.debug("unravel '%s'" % txt) #logger.debug("unravel '%s'" % txt)
if binding not in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST, if binding not in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST,
BINDING_SOAP, BINDING_URI, BINDING_HTTP_ARTIFACT, BINDING_SOAP, BINDING_URI, BINDING_HTTP_ARTIFACT,
@@ -309,7 +320,8 @@ class Entity(HTTPBase):
elif binding == BINDING_HTTP_POST: elif binding == BINDING_HTTP_POST:
xmlstr = base64.b64decode(txt) xmlstr = base64.b64decode(txt)
elif binding == BINDING_SOAP: elif binding == BINDING_SOAP:
func = getattr(soap, "parse_soap_enveloped_saml_%s" % msgtype) func = getattr(soap,
"parse_soap_enveloped_saml_%s" % msgtype)
xmlstr = func(txt) xmlstr = func(txt)
elif binding == BINDING_HTTP_ARTIFACT: elif binding == BINDING_HTTP_ARTIFACT:
xmlstr = base64.b64decode(txt) xmlstr = base64.b64decode(txt)
@@ -320,7 +332,8 @@ class Entity(HTTPBase):
return xmlstr return xmlstr
def parse_soap_message(self, text): @staticmethod
def parse_soap_message(text):
""" """
:param text: The SOAP message :param text: The SOAP message
@@ -330,7 +343,8 @@ class Entity(HTTPBase):
ecp, ecp,
samlp]) samlp])
def unpack_soap_message(self, text): @staticmethod
def unpack_soap_message(text):
""" """
Picks out the parts of the SOAP message, body and headers apart Picks out the parts of the SOAP message, body and headers apart
:param text: The SOAP message :param text: The SOAP message
@@ -438,7 +452,8 @@ class Entity(HTTPBase):
msg.extension_elements = extensions msg.extension_elements = extensions
def _response(self, in_response_to, consumer_url=None, status=None, def _response(self, in_response_to, consumer_url=None, status=None,
issuer=None, sign=False, to_sign=None, encrypt_assertion=False, encrypt_cert=None, **kwargs): issuer=None, sign=False, to_sign=None,
encrypt_assertion=False, encrypt_cert=None, **kwargs):
""" Create a Response. """ Create a Response.
:param in_response_to: The session identifier of the request :param in_response_to: The session identifier of the request
@@ -471,10 +486,13 @@ class Entity(HTTPBase):
if encrypt_assertion: if encrypt_assertion:
sign_class = [(class_name(response), response.id)] sign_class = [(class_name(response), response.id)]
if sign: if sign:
response.signature = pre_signature_part(response.id, self.sec.my_cert, 1) response.signature = pre_signature_part(response.id,
self.sec.my_cert, 1)
cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary)
_, cert_file = make_temp("%s" % encrypt_cert, decode=False) _, cert_file = make_temp("%s" % encrypt_cert, decode=False)
response = cbxs.encrypt_assertion(response, cert_file, pre_encryption_part())#template(response.assertion.id)) response = cbxs.encrypt_assertion(response, cert_file,
pre_encryption_part())
# template(response.assertion.id))
if sign: if sign:
return signed_instance_factory(response, self.sec, sign_class) return signed_instance_factory(response, self.sec, sign_class)
else: else:
@@ -520,13 +538,14 @@ class Entity(HTTPBase):
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def srv2typ(self, service): @staticmethod
def srv2typ(service):
for typ in ["aa", "pdp", "aq"]: for typ in ["aa", "pdp", "aq"]:
if service in ENDPOINTS[typ]: if service in ENDPOINTS[typ]:
if typ == "aa": if typ == "aa":
return "attribute_authority" return "attribute_authority"
elif typ == "aq": elif typ == "aq":
return "authn_authority" return "authn_authority"
else: else:
return typ return typ
@@ -570,12 +589,14 @@ class Entity(HTTPBase):
origdoc = xmlstr origdoc = xmlstr
xmlstr = self.unravel(xmlstr, binding, request_cls.msgtype) xmlstr = self.unravel(xmlstr, binding, request_cls.msgtype)
must = self.config.getattr("want_authn_requests_signed", "idp") must = self.config.getattr("want_authn_requests_signed", "idp")
only_valid_cert = self.config.getattr("want_authn_requests_only_with_valid_cert", "idp") only_valid_cert = self.config.getattr(
"want_authn_requests_only_with_valid_cert", "idp")
if only_valid_cert is None: if only_valid_cert is None:
only_valid_cert = False only_valid_cert = False
if only_valid_cert: if only_valid_cert:
must = True must = True
_request = _request.loads(xmlstr, binding, origdoc=origdoc, must=must, only_valid_cert=only_valid_cert) _request = _request.loads(xmlstr, binding, origdoc=origdoc, must=must,
only_valid_cert=only_valid_cert)
_log_debug("Loaded request") _log_debug("Loaded request")
@@ -674,14 +695,14 @@ class Entity(HTTPBase):
return response return response
def create_artifact_resolve(self, artifact, destination, sid, consent=None, def create_artifact_resolve(self, artifact, destination, sessid,
extensions=None, sign=False): consent=None, extensions=None, sign=False):
""" """
Create a ArtifactResolve request Create a ArtifactResolve request
:param artifact: :param artifact:
:param destination: :param destination:
:param sid: session id :param sessid: session id
:param consent: :param consent:
:param extensions: :param extensions:
:param sign: :param sign:
@@ -690,7 +711,7 @@ class Entity(HTTPBase):
artifact = Artifact(text=artifact) artifact = Artifact(text=artifact)
return self._message(ArtifactResolve, destination, sid, return self._message(ArtifactResolve, destination, sessid,
consent, extensions, sign, artifact=artifact) consent, extensions, sign, artifact=artifact)
def create_artifact_response(self, request, artifact, bindings=None, def create_artifact_response(self, request, artifact, bindings=None,
@@ -763,7 +784,7 @@ class Entity(HTTPBase):
was not. was not.
""" """
return self._parse_request(xmlstr, request.ManageNameIDRequest, return self._parse_request(xmlstr, saml_request.ManageNameIDRequest,
"manage_name_id_service", binding) "manage_name_id_service", binding)
def create_manage_name_id_response(self, request, bindings=None, def create_manage_name_id_response(self, request, bindings=None,
@@ -781,17 +802,20 @@ class Entity(HTTPBase):
def parse_manage_name_id_request_response(self, string, def parse_manage_name_id_request_response(self, string,
binding=BINDING_SOAP): binding=BINDING_SOAP):
return self._parse_response(string, response.ManageNameIDResponse, return self._parse_response(string, saml_response.ManageNameIDResponse,
"manage_name_id_service", binding) "manage_name_id_service", binding)
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def _parse_response(self, xmlstr, response_cls, service, binding, outstanding_certs=None, **kwargs): def _parse_response(self, xmlstr, response_cls, service, binding,
outstanding_certs=None, **kwargs):
""" Deal with a Response """ Deal with a Response
:param xmlstr: The response as a xml string :param xmlstr: The response as a xml string
:param response_cls: What type of response it is :param response_cls: What type of response it is
:param binding: What type of binding this message came through. :param binding: What type of binding this message came through.
:param outstanding_certs: Certificates that belongs to me that the
IdP may have used to encrypt a response/assertion/..
:param kwargs: Extra key word arguments :param kwargs: Extra key word arguments
:return: None if the reply doesn't contain a valid SAML Response, :return: None if the reply doesn't contain a valid SAML Response,
otherwise the response. otherwise the response.
@@ -800,20 +824,20 @@ class Entity(HTTPBase):
response = None response = None
if self.config.accepted_time_diff: if self.config.accepted_time_diff:
kwargs["timeslack"] = self.config.accepted_time_diff timeslack = self.config.accepted_time_diff
if "asynchop" not in kwargs: if "asynchop" not in kwargs:
if binding in [BINDING_SOAP, BINDING_PAOS]: if binding in [BINDING_SOAP, BINDING_PAOS]:
kwargs["asynchop"] = False asynchop = False
else: else:
kwargs["asynchop"] = True asynchop = True
if xmlstr: if xmlstr:
if "return_addrs" not in kwargs: if "return_addrs" not in kwargs:
if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]: if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]:
try: try:
# expected return address # expected return address
kwargs["return_addrs"] = self.config.endpoint( return_addrs = self.config.endpoint(
service, binding=binding) service, binding=binding)
except Exception: except Exception:
logger.info("Not supposed to handle this!") logger.info("Not supposed to handle this!")
@@ -830,7 +854,9 @@ class Entity(HTTPBase):
if outstanding_certs is not None: if outstanding_certs is not None:
_response = samlp.any_response_from_string(xmlstr) _response = samlp.any_response_from_string(xmlstr)
if len(_response.encrypted_assertion) > 0: if len(_response.encrypted_assertion) > 0:
_, cert_file = make_temp("%s" % outstanding_certs[_response.in_response_to]["key"], decode=False) _, cert_file = make_temp(
"%s" % outstanding_certs[
_response.in_response_to]["key"], decode=False)
cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary)
xmlstr = cbxs.decrypt(xmlstr, cert_file) xmlstr = cbxs.decrypt(xmlstr, cert_file)
if not xmlstr: # Not a valid reponse if not xmlstr: # Not a valid reponse
@@ -854,9 +880,11 @@ class Entity(HTTPBase):
if hasattr(response.response, 'encrypted_assertion'): if hasattr(response.response, 'encrypted_assertion'):
for encrypted_assertion in response.response.encrypted_assertion: for encrypted_assertion in response.response.encrypted_assertion:
if encrypted_assertion.extension_elements is not None: if encrypted_assertion.extension_elements is not None:
assertion_list = extension_elements_to_elements(encrypted_assertion.extension_elements, [saml]) assertion_list = extension_elements_to_elements(
encrypted_assertion.extension_elements, [saml])
for assertion in assertion_list: for assertion in assertion_list:
_assertion = saml.assertion_from_string(str(assertion)) _assertion = saml.assertion_from_string(
str(assertion))
response.response.assertion.append(_assertion) response.response.assertion.append(_assertion)
if response: if response:
@@ -887,7 +915,7 @@ class Entity(HTTPBase):
was not. was not.
""" """
return self._parse_request(xmlstr, request.LogoutRequest, return self._parse_request(xmlstr, saml_request.LogoutRequest,
"single_logout_service", binding) "single_logout_service", binding)
def use_artifact(self, message, endpoint_index=0): def use_artifact(self, message, endpoint_index=0):
@@ -964,7 +992,7 @@ class Entity(HTTPBase):
kwargs = {"entity_id": self.config.entityid, kwargs = {"entity_id": self.config.entityid,
"attribute_converters": self.config.attribute_converters} "attribute_converters": self.config.attribute_converters}
resp = self._parse_response(xmlstr, response.ArtifactResponse, resp = self._parse_response(xmlstr, saml_response.ArtifactResponse,
"artifact_resolve", BINDING_SOAP, "artifact_resolve", BINDING_SOAP,
**kwargs) **kwargs)
# should just be one # should just be one