Added the possibility to set signature and digest algorithm on all the functions I identified.

pysaml2 has a default value for sign and digest. To make it possible to always use the same algorithm this default value has been replaced with a singleton class.

 The first time the singleton class is instantiated the sign and digest algorithm will be set. After that it cannot be changed. A good place to setup this single class is in the server setup.

 Example:
        ds.DefaultSignature(ds.SIG_RSA_SHA512, ds.DIGEST_SHA512)
This commit is contained in:
Hans Hörberg
2015-11-06 12:17:34 +01:00
parent 0f209eb549
commit 3b84f65d84
10 changed files with 700 additions and 429 deletions

View File

@@ -43,6 +43,7 @@ tests_require = [
'pytest',
'mako',
'webob',
'mock'
#'pytest-coverage',
]

View File

@@ -135,7 +135,7 @@ class Saml2Client(Base):
raise SignOnError(
"No supported bindings available for authentication")
def global_logout(self, name_id, reason="", expire=None, sign=None):
def global_logout(self, name_id, reason="", expire=None, sign=None, sign_alg=None, digest_alg=None):
""" More or less a layer of indirection :-/
Bootstrapping the whole thing by finding all the IdPs that should
be notified.
@@ -160,10 +160,10 @@ class Saml2Client(Base):
# find out which IdPs/AAs I should notify
entity_ids = self.users.issuers_of_info(name_id)
return self.do_logout(name_id, entity_ids, reason, expire, sign)
return self.do_logout(name_id, entity_ids, reason, expire, sign, sign_alg=sign_alg)
def do_logout(self, name_id, entity_ids, reason, expire, sign=None,
expected_binding=None, **kwargs):
expected_binding=None, sign_alg=None, digest_alg=None, **kwargs):
"""
:param name_id: Identifier of the Subject (a NameID instance)
@@ -226,11 +226,11 @@ class Saml2Client(Base):
key = None
if sign:
if binding == BINDING_HTTP_REDIRECT:
sigalg = kwargs.get("sigalg", ds.sig_default)
sigalg = kwargs.get("sigalg", ds.DefaultSignature().get_sign_alg())
key = kwargs.get("key", self.signkey)
srequest = str(request)
else:
srequest = self.sign(request)
srequest = self.sign(request, sign_alg=sign_alg)
else:
srequest = str(request)
@@ -290,7 +290,7 @@ class Saml2Client(Base):
identity = self.users.get_identity(name_id)[0]
return bool(identity)
def handle_logout_response(self, response):
def handle_logout_response(self, response, sign_alg=None, digest_alg=None):
""" handles a Logout response
:param response: A response.Response instance
@@ -309,10 +309,12 @@ class Saml2Client(Base):
return 0, "200 Ok", [("Content-type", "text/html")], []
else:
status["entity_ids"].remove(issuer)
if "sign_alg" in status:
sign_alg = status["sign_alg"]
return self.do_logout(decode(status["name_id"]),
status["entity_ids"],
status["reason"], status["not_on_or_after"],
status["sign"])
status["sign"], sign_alg=sign_alg)
def _use_soap(self, destination, query_type, **kwargs):
_create_func = getattr(self, "create_%s" % query_type)

View File

@@ -202,7 +202,7 @@ class Base(Entity):
nameid_format=None,
service_url_binding=None, message_id=0,
consent=None, extensions=None, sign=None,
allow_create=False, sign_prepare=False, **kwargs):
allow_create=False, sign_prepare=False, sign_alg=None, digest_alg=None, **kwargs):
""" Creates an authentication request.
:param destination: Where the request should be sent.
@@ -339,15 +339,15 @@ class Base(Entity):
return self._message(AuthnRequest, destination, message_id,
consent, extensions, sign, sign_prepare,
protocol_binding=binding,
scoping=scoping, nsprefix=nsprefix, **args)
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, **args)
return self._message(AuthnRequest, destination, message_id, consent,
extensions, sign, sign_prepare,
protocol_binding=binding,
scoping=scoping, nsprefix=nsprefix, **args)
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, **args)
def create_attribute_query(self, destination, name_id=None,
attribute=None, message_id=0, consent=None,
extensions=None, sign=False, sign_prepare=False,
extensions=None, sign=False, sign_prepare=False, sign_alg=None, digest_alg=None,
**kwargs):
""" Constructs an AttributeQuery
@@ -404,7 +404,7 @@ class Base(Entity):
return self._message(AttributeQuery, destination, message_id, consent,
extensions, sign, sign_prepare, subject=subject,
attribute=attribute, nsprefix=nsprefix)
attribute=attribute, nsprefix=nsprefix, sign_alg=sign_alg)
# MUST use SOAP for
# AssertionIDRequest, SubjectQuery,
@@ -412,7 +412,7 @@ class Base(Entity):
def create_authz_decision_query(self, destination, action,
evidence=None, resource=None, subject=None,
message_id=0, consent=None, extensions=None,
sign=None, **kwargs):
sign=None, sign_alg=None, digest_alg=None, **kwargs):
""" Creates an authz decision query.
:param destination: The IdP endpoint
@@ -430,7 +430,7 @@ class Base(Entity):
return self._message(AuthzDecisionQuery, destination, message_id,
consent, extensions, sign, action=action,
evidence=evidence, resource=resource,
subject=subject, **kwargs)
subject=subject, sign_alg=sign_alg, **kwargs)
def create_authz_decision_query_using_assertion(self, destination,
assertion, action=None,
@@ -482,7 +482,7 @@ class Base(Entity):
def create_authn_query(self, subject, destination=None, authn_context=None,
session_index="", message_id=0, consent=None,
extensions=None, sign=False, nsprefix=None):
extensions=None, sign=False, nsprefix=None, sign_alg=None, digest_alg=None):
"""
:param subject: The subject its all about as a <Subject> instance
@@ -499,14 +499,14 @@ class Base(Entity):
extensions, sign, subject=subject,
session_index=session_index,
requested_authn_context=authn_context,
nsprefix=nsprefix)
nsprefix=nsprefix, sign_alg=sign_alg)
def create_name_id_mapping_request(self, name_id_policy,
name_id=None, base_id=None,
encrypted_id=None, destination=None,
message_id=0, consent=None,
extensions=None, sign=False,
nsprefix=None):
nsprefix=None, sign_alg=None, digest_alg=None):
"""
:param name_id_policy:
@@ -528,17 +528,17 @@ class Base(Entity):
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy, name_id=name_id,
nsprefix=nsprefix)
nsprefix=nsprefix, sign_alg=sign_alg)
elif base_id:
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy, base_id=base_id,
nsprefix=nsprefix)
nsprefix=nsprefix, sign_alg=sign_alg)
else:
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy,
encrypted_id=encrypted_id, nsprefix=nsprefix)
encrypted_id=encrypted_id, nsprefix=nsprefix, sign_alg=sign_alg)
# ======== response handling ===========

View File

@@ -409,9 +409,9 @@ class Entity(HTTPBase):
# --------------------------------------------------------------------------
def sign(self, msg, mid=None, to_sign=None, sign_prepare=False):
def sign(self, msg, mid=None, to_sign=None, sign_prepare=False, sign_alg=None, digest_alg=None):
if msg.signature is None:
msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1)
msg.signature = pre_signature_part(msg.id, self.sec.my_cert, 1, sign_alg=sign_alg)
if sign_prepare:
return msg
@@ -429,7 +429,7 @@ class Entity(HTTPBase):
def _message(self, request_cls, destination=None, message_id=0,
consent=None, extensions=None, sign=False, sign_prepare=False,
nsprefix=None, **kwargs):
nsprefix=None, sign_alg=None, digest_alg=None, **kwargs):
"""
Some parameters appear in all requests so simplify by doing
it in one place
@@ -468,7 +468,7 @@ class Entity(HTTPBase):
req.register_prefix(nsprefix)
if sign:
return reqid, self.sign(req, sign_prepare=sign_prepare)
return reqid, self.sign(req, sign_prepare=sign_prepare, sign_alg=sign_alg)
else:
logger.info("REQUEST: %s" % req)
return reqid, req
@@ -559,8 +559,10 @@ class Entity(HTTPBase):
def _response(self, in_response_to, consumer_url=None, status=None,
issuer=None, sign=False, to_sign=None, sp_entity_id=None,
encrypt_assertion=False, encrypt_assertion_self_contained=False, encrypted_advice_attributes=False,
encrypt_cert_advice=None, encrypt_cert_assertion=None,sign_assertion=None, pefim=False, **kwargs):
encrypt_assertion=False, encrypt_assertion_self_contained=False,
encrypted_advice_attributes=False,
encrypt_cert_advice=None, encrypt_cert_assertion=None,sign_assertion=None,
pefim=False, sign_alg=None, digest_alg=None, **kwargs):
""" Create a Response.
Encryption:
encrypt_assertion must be true for encryption to be performed. If encrypted_advice_attributes also is
@@ -596,7 +598,7 @@ class Entity(HTTPBase):
response = response_factory(issuer=_issuer,
in_response_to=in_response_to,
status=status)
status=status, sign_alg=sign_alg)
if consumer_url:
response.destination = consumer_url
@@ -616,7 +618,7 @@ class Entity(HTTPBase):
len(response.assertion.advice.assertion) == 1):
if sign:
response.signature = pre_signature_part(response.id,
self.sec.my_cert, 1)
self.sec.my_cert, 1, sign_alg=sign_alg)
sign_class = [(class_name(response), response.id)]
cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary)
encrypt_advice = False
@@ -635,7 +637,9 @@ class Entity(HTTPBase):
for tmp_assertion in _advice_assertions:
to_sign_advice = []
if sign_assertion and not pefim:
tmp_assertion.signature = pre_signature_part(tmp_assertion.id, self.sec.my_cert, 1)
tmp_assertion.signature = pre_signature_part(tmp_assertion.id,
self.sec.my_cert, 1,
sign_alg=sign_alg)
to_sign_advice.append((class_name(tmp_assertion), tmp_assertion.id))
#tmp_assertion = response.assertion.advice.assertion[0]
_assertion.advice.encrypted_assertion[0].add_extension_element(tmp_assertion)
@@ -661,7 +665,9 @@ class Entity(HTTPBase):
if not isinstance(_assertions, list):
_assertions = [_assertions]
for _assertion in _assertions:
_assertion.signature = pre_signature_part(_assertion.id, self.sec.my_cert, 1)
_assertion.signature = pre_signature_part(_assertion.id, self.sec.my_cert,
1,
sign_alg=sign_alg)
to_sign_assertion.append((class_name(_assertion), _assertion.id))
if encrypt_assertion_self_contained:
try:
@@ -685,11 +691,11 @@ class Entity(HTTPBase):
return response
if sign:
return self.sign(response, to_sign=to_sign)
return self.sign(response, to_sign=to_sign, sign_alg=sign_alg)
else:
return response
def _status_response(self, response_class, issuer, status, sign=False,
def _status_response(self, response_class, issuer, status, sign=False, sign_alg=None, digest_alg=None,
**kwargs):
""" Create a StatusResponse.
@@ -718,7 +724,7 @@ class Entity(HTTPBase):
status=status, **kwargs)
if sign:
return self.sign(response, mid)
return self.sign(response, mid, sign_alg=sign_alg)
else:
return response
@@ -797,7 +803,7 @@ class Entity(HTTPBase):
# ------------------------------------------------------------------------
def create_error_response(self, in_response_to, destination, info,
sign=False, issuer=None, **kwargs):
sign=False, issuer=None, sign_alg=None, digest_alg=None, **kwargs):
""" Create a error response.
:param in_response_to: The identifier of the message this is a response
@@ -813,7 +819,7 @@ class Entity(HTTPBase):
status = error_status_factory(info)
return self._response(in_response_to, destination, status, issuer,
sign)
sign, sign_alg=sign_alg)
# ------------------------------------------------------------------------
@@ -821,7 +827,7 @@ class Entity(HTTPBase):
subject_id=None, name_id=None,
reason=None, expire=None, message_id=0,
consent=None, extensions=None, sign=False,
session_indexes=None):
session_indexes=None, sign_alg=None, digest_alg=None):
""" Constructs a LogoutRequest
:param destination: Destination of the request
@@ -865,10 +871,10 @@ class Entity(HTTPBase):
return self._message(LogoutRequest, destination, message_id,
consent, extensions, sign, name_id=name_id,
reason=reason, not_on_or_after=expire,
issuer=self._issuer(), **args)
issuer=self._issuer(), sign_alg=sign_alg, **args)
def create_logout_response(self, request, bindings=None, status=None,
sign=False, issuer=None):
sign=False, issuer=None, sign_alg=None, digest_alg=None):
""" Create a LogoutResponse.
:param request: The request this is a response to
@@ -886,14 +892,14 @@ class Entity(HTTPBase):
issuer = self._issuer()
response = self._status_response(samlp.LogoutResponse, issuer, status,
sign, **rinfo)
sign, sign_alg=sign_alg, **rinfo)
logger.info("Response: %s" % (response,))
return response
def create_artifact_resolve(self, artifact, destination, sessid,
consent=None, extensions=None, sign=False):
consent=None, extensions=None, sign=False, sign_alg=None, digest_alg=None):
"""
Create a ArtifactResolve request
@@ -909,10 +915,10 @@ class Entity(HTTPBase):
artifact = Artifact(text=artifact)
return self._message(ArtifactResolve, destination, sessid,
consent, extensions, sign, artifact=artifact)
consent, extensions, sign, artifact=artifact, sign_alg=sign_alg)
def create_artifact_response(self, request, artifact, bindings=None,
status=None, sign=False, issuer=None):
status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None):
"""
Create an ArtifactResponse
:return:
@@ -920,7 +926,7 @@ class Entity(HTTPBase):
rinfo = self.response_args(request, bindings)
response = self._status_response(ArtifactResponse, issuer, status,
sign=sign, **rinfo)
sign=sign, sign_alg=sign_alg, **rinfo)
msg = element_to_extension_element(self.artifact[artifact])
response.extension_elements = [msg]
@@ -933,7 +939,7 @@ class Entity(HTTPBase):
consent=None, extensions=None, sign=False,
name_id=None, new_id=None,
encrypted_id=None, new_encrypted_id=None,
terminate=None):
terminate=None, sign_alg=None, digest_alg=None):
"""
:param destination:
@@ -969,7 +975,7 @@ class Entity(HTTPBase):
"One of NewID, NewEncryptedNameID or Terminate has to be provided")
return self._message(ManageNameIDRequest, destination, consent=consent,
extensions=extensions, sign=sign, **kwargs)
extensions=extensions, sign=sign, sign_alg=sign_alg, **kwargs)
def parse_manage_name_id_request(self, xmlstr, binding=BINDING_SOAP):
""" Deal with a LogoutRequest
@@ -985,13 +991,13 @@ class Entity(HTTPBase):
"manage_name_id_service", binding)
def create_manage_name_id_response(self, request, bindings=None,
status=None, sign=False, issuer=None,
status=None, sign=False, issuer=None, sign_alg=None, digest_alg=None,
**kwargs):
rinfo = self.response_args(request, bindings)
response = self._status_response(samlp.ManageNameIDResponse, issuer,
status, sign, **rinfo)
status, sign, sign_alg=sign_alg, **rinfo)
logger.info("Response: %s" % (response,))

View File

@@ -754,7 +754,7 @@ def entity_descriptor(confd):
return entd
def entities_descriptor(eds, valid_for, name, ident, sign, secc):
def entities_descriptor(eds, valid_for, name, ident, sign, secc, sign_alg=None, digest_alg=None):
entities = md.EntitiesDescriptor(entity_descriptor=eds)
if valid_for:
entities.valid_until = in_a_while(hours=valid_for)
@@ -775,7 +775,7 @@ def entities_descriptor(eds, valid_for, name, ident, sign, secc):
raise SAMLError("If you want to do signing you should define " +
"where your public key are")
entities.signature = pre_signature_part(ident, secc.my_cert, 1)
entities.signature = pre_signature_part(ident, secc.my_cert, 1, sign_alg=sign_alg)
entities.id = ident
xmldoc = secc.sign_statement("%s" % entities, class_name(entities))
entities = md.entities_descriptor_from_string(xmldoc)
@@ -785,7 +785,7 @@ def entities_descriptor(eds, valid_for, name, ident, sign, secc):
return entities, xmldoc
def sign_entity_descriptor(edesc, ident, secc):
def sign_entity_descriptor(edesc, ident, secc, sign_alg=None, digest_alg=None):
"""
:param edesc: EntityDescriptor instance
@@ -797,7 +797,7 @@ def sign_entity_descriptor(edesc, ident, secc):
if not ident:
ident = sid()
edesc.signature = pre_signature_part(ident, secc.my_cert, 1)
edesc.signature = pre_signature_part(ident, secc.my_cert, 1, sign_alg=sign_alg)
edesc.id = ident
xmldoc = secc.sign_statement("%s" % edesc, class_name(edesc))
edesc = md.entity_descriptor_from_string(xmldoc)

View File

@@ -332,7 +332,8 @@ class Server(Entity):
sign_assertion=False, sign_response=False,
best_effort=False, encrypt_assertion=False,
encrypt_cert_advice=None, encrypt_cert_assertion=None, authn_statement=None,
encrypt_assertion_self_contained=False, encrypted_advice_attributes=False, pefim=False):
encrypt_assertion_self_contained=False, encrypted_advice_attributes=False, pefim=False,
sign_alg=None, digest_alg=None):
""" Create a response. A layer of indirection.
:param in_response_to: The session identifier of the request
@@ -397,7 +398,8 @@ class Server(Entity):
to_sign = []
if not encrypt_assertion:
if sign_assertion:
assertion.signature = pre_signature_part(assertion.id, self.sec.my_cert, 1)
assertion.signature = pre_signature_part(assertion.id, self.sec.my_cert, 1,
sign_alg=sign_alg)
to_sign.append((class_name(assertion), assertion.id))
#if not encrypted_advice_attributes:
@@ -420,12 +422,14 @@ class Server(Entity):
self.session_db.store_assertion(assertion, to_sign)
return self._response(in_response_to, consumer_url, status, issuer,
sign_response, to_sign,sp_entity_id=sp_entity_id, encrypt_assertion=encrypt_assertion,
sign_response, to_sign,sp_entity_id=sp_entity_id,
encrypt_assertion=encrypt_assertion,
encrypt_cert_advice=encrypt_cert_advice,
encrypt_cert_assertion=encrypt_cert_assertion,
encrypt_assertion_self_contained=encrypt_assertion_self_contained,
encrypted_advice_attributes=encrypted_advice_attributes,sign_assertion=sign_assertion,
pefim=pefim,
encrypted_advice_attributes=encrypted_advice_attributes,
sign_assertion=sign_assertion,
pefim=pefim, sign_alg=sign_alg,
**args)
# ------------------------------------------------------------------------
@@ -435,7 +439,7 @@ class Server(Entity):
sp_entity_id, userid="", name_id=None,
status=None, issuer=None,
sign_assertion=False, sign_response=False,
attributes=None, **kwargs):
attributes=None, sign_alg=None, digest_alg=None, **kwargs):
""" Create an attribute assertion response.
:param identity: A dictionary with attributes and values that are
@@ -485,14 +489,14 @@ class Server(Entity):
if sign_assertion:
assertion.signature = pre_signature_part(assertion.id,
self.sec.my_cert, 1)
self.sec.my_cert, 1, sign_alg=sign_alg)
# Just the assertion or the response and the assertion ?
to_sign = [(class_name(assertion), assertion.id)]
args["assertion"] = assertion
return self._response(in_response_to, destination, status, issuer,
sign_response, to_sign, **args)
sign_response, to_sign, sign_alg=sign_alg, **args)
# ------------------------------------------------------------------------
@@ -502,7 +506,7 @@ class Server(Entity):
sign_response=None, sign_assertion=None,
encrypt_cert_advice=None, encrypt_cert_assertion=None, encrypt_assertion=None,
encrypt_assertion_self_contained=True,
encrypted_advice_attributes=False, pefim=False,
encrypted_advice_attributes=False, pefim=False, sign_alg=None, digest_alg=None,
**kwargs):
""" Constructs an AuthenticationResponse
@@ -644,7 +648,8 @@ class Server(Entity):
encrypted_advice_attributes=encrypted_advice_attributes,
encrypt_cert_advice=encrypt_cert_advice,
encrypt_cert_assertion=encrypt_cert_assertion,
pefim=pefim)
pefim=pefim,
sign_alg=sign_alg)
return self._authn_response(in_response_to, # in_response_to
destination, # consumer_url
sp_entity_id, # sp_entity_id
@@ -661,7 +666,8 @@ class Server(Entity):
encrypted_advice_attributes=encrypted_advice_attributes,
encrypt_cert_advice=encrypt_cert_advice,
encrypt_cert_assertion=encrypt_cert_assertion,
pefim=pefim)
pefim=pefim,
sign_alg=sign_alg)
except MissingValue as exc:
return self.create_error_response(in_response_to, destination,
@@ -681,7 +687,7 @@ class Server(Entity):
authn_decl=authn_decl)
#noinspection PyUnusedLocal
def create_assertion_id_request_response(self, assertion_id, sign=False,
def create_assertion_id_request_response(self, assertion_id, sign=False, sign_alg=None, digest_alg=None,
**kwargs):
"""
@@ -698,7 +704,7 @@ class Server(Entity):
if to_sign:
if assertion.signature is None:
assertion.signature = pre_signature_part(assertion.id,
self.sec.my_cert, 1)
self.sec.my_cert, 1, sign_alg=sign_alg)
return signed_instance_factory(assertion, self.sec, to_sign)
else:
@@ -708,7 +714,7 @@ class Server(Entity):
def create_name_id_mapping_response(self, name_id=None, encrypted_id=None,
in_response_to=None,
issuer=None, sign_response=False,
status=None, **kwargs):
status=None, sign_alg=None, digest_alg=None, **kwargs):
"""
protocol for mapping a principal's name identifier into a
different name identifier for the same principal.
@@ -730,7 +736,7 @@ class Server(Entity):
in_response_to=in_response_to, **ms_args)
if sign_response:
return self.sign(_resp)
return self.sign(_resp, sign_alg=sign_alg)
else:
logger.info("Message: %s" % _resp)
return _resp
@@ -738,7 +744,7 @@ class Server(Entity):
def create_authn_query_response(self, subject, session_index=None,
requested_context=None, in_response_to=None,
issuer=None, sign_response=False,
status=None, **kwargs):
status=None, sign_alg=None, digest_alg=None, **kwargs):
"""
A successful <Response> will contain one or more assertions containing
authentication statements.
@@ -759,7 +765,7 @@ class Server(Entity):
args = {}
return self._response(in_response_to, "", status, issuer,
sign_response, to_sign=[], **args)
sign_response, to_sign=[], sign_alg=sign_alg, **args)
# ---------

View File

@@ -1761,7 +1761,7 @@ class SecurityContext(object):
return self.sign_statement(statement, class_name(
samlp.AttributeQuery()), **kwargs)
def multiple_signatures(self, statement, to_sign, key=None, key_file=None):
def multiple_signatures(self, statement, to_sign, key=None, key_file=None, sign_alg=None, digest_alg=None):
"""
Sign multiple parts of a statement
@@ -1780,7 +1780,7 @@ class SecurityContext(object):
sid = item.id
if not item.signature:
item.signature = pre_signature_part(sid, self.cert_file)
item.signature = pre_signature_part(sid, self.cert_file, sign_alg=sign_alg)
statement = self.sign_statement(statement, class_name(item),
key=key, key_file=key_file,
@@ -1806,9 +1806,9 @@ def pre_signature_part(ident, public_key=None, identifier=None,
"""
if not digest_alg:
digest_alg=ds.digest_default
digest_alg = ds.DefaultSignature().get_digest_alg()
if not sign_alg:
sign_alg=ds.sig_default
sign_alg = ds.DefaultSignature().get_sign_alg()
signature_method = ds.SignatureMethod(algorithm=sign_alg)
canonicalization_method = ds.CanonicalizationMethod(
algorithm=ds.ALG_EXC_C14N)
@@ -1918,12 +1918,12 @@ def pre_encrypt_assertion(response):
return response
def response_factory(sign=False, encrypt=False, **kwargs):
def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None, **kwargs):
response = samlp.Response(id=sid(), version=VERSION,
issue_instant=instant())
if sign:
response.signature = pre_signature_part(kwargs["id"])
response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg)
if encrypt:
pass

File diff suppressed because it is too large Load Diff

View File

@@ -32,6 +32,7 @@ from saml2 import BINDING_HTTP_REDIRECT
from py.test import raises
from pathutils import full_path
import saml2.xmldsig as ds
nid = NameID(name_qualifier="foo", format=NAMEID_FORMAT_TRANSIENT,
text="123456")
@@ -86,6 +87,7 @@ def generate_cert():
class TestServer1():
def setup_class(self):
self.server = Server("idp_conf")

View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from saml2.authn_context import INTERNETPROTOCOLPASSWORD
from saml2.saml import NameID, NAMEID_FORMAT_TRANSIENT
from saml2.samlp import response_from_string
from saml2.server import Server
from saml2 import client
from saml2 import config
from mock.mock import Mock, MagicMock
import saml2.xmldsig as ds
nid = NameID(name_qualifier="foo", format=NAMEID_FORMAT_TRANSIENT,
text="123456")
AUTHN = {
"class_ref": INTERNETPROTOCOLPASSWORD,
"authn_auth": "http://www.example.com/login"
}
def _eq(l1, l2):
return set(l1) == set(l2)
BASEDIR = os.path.abspath(os.path.dirname(__file__))
def get_ava(assertion):
ava = {}
for statement in assertion.attribute_statement:
for attr in statement.attribute:
value = []
for tmp_val in attr.attribute_value:
value.append(tmp_val.text)
key = attr.friendly_name
if key is None or len(key) == 0:
key = attr.text
ava[key] = value
return ava
class TestSignedResponse():
def setup_class(self):
self.server = Server("idp_conf")
sign_alg = Mock()
sign_alg.return_value = ds.SIG_RSA_SHA512
digest_alg = Mock()
digest_alg.return_value = ds.DIGEST_SHA512
self.restet_default = ds.DefaultSignature
ds.DefaultSignature = MagicMock()
ds.DefaultSignature().get_sign_alg = sign_alg
ds.DefaultSignature().get_digest_alg = digest_alg
conf = config.SPConfig()
conf.load_file("server_conf")
self.client = client.Saml2Client(conf)
self.name_id = self.server.ident.transient_nameid(
"urn:mace:example.com:saml:roland:sp", "id12")
self.ava = {"givenName": ["Derek"], "surName": ["Jeter"],
"mail": ["derek@nyy.mlb.com"], "title": "The man"}
def teardown_class(self):
ds.DefaultSignature = self.restet_default
self.server.close()
def verify_assertion(self, assertion):
assert assertion
assert assertion[0].attribute_statement
ava = ava = get_ava(assertion[0])
assert ava ==\
{'mail': ['derek@nyy.mlb.com'], 'givenName': ['Derek'],
'surName': ['Jeter'], 'title': ['The man']}
def test_signed_response(self):
print(ds.DefaultSignature().get_digest_alg())
name_id = self.server.ident.transient_nameid(
"urn:mace:example.com:saml:roland:sp", "id12")
ava = {"givenName": ["Derek"], "surName": ["Jeter"],
"mail": ["derek@nyy.mlb.com"], "title": "The man"}
signed_resp = self.server.create_authn_response(
ava,
"id12", # in_response_to
"http://lingon.catalogix.se:8087/", # consumer_url
"urn:mace:example.com:saml:roland:sp", # sp_entity_id
name_id=name_id,
sign_assertion=True
)
print(signed_resp)
assert signed_resp
sresponse = response_from_string(signed_resp)
assert ds.SIG_RSA_SHA512 in str(sresponse), "Not correctly signed!"
assert ds.DIGEST_SHA512 in str(sresponse), "Not correctly signed!"
def test_signed_response_1(self):
signed_resp = self.server.create_authn_response(
self.ava,
"id12", # in_response_to
"http://lingon.catalogix.se:8087/", # consumer_url
"urn:mace:example.com:saml:roland:sp", # sp_entity_id
name_id=self.name_id,
sign_response=True,
sign_assertion=True,
)
sresponse = response_from_string(signed_resp)
assert ds.SIG_RSA_SHA512 in str(sresponse), "Not correctly signed!"
assert ds.DIGEST_SHA512 in str(sresponse), "Not correctly signed!"
valid = self.server.sec.verify_signature(signed_resp,
self.server.config.cert_file,
node_name='urn:oasis:names:tc:SAML:2.0:protocol:Response',
node_id=sresponse.id,
id_attr="")
assert valid
assert ds.SIG_RSA_SHA512 in str(sresponse.assertion[0]), "Not correctly signed!"
assert ds.DIGEST_SHA512 in str(sresponse.assertion[0]), "Not correctly signed!"
valid = self.server.sec.verify_signature(signed_resp,
self.server.config.cert_file,
node_name='urn:oasis:names:tc:SAML:2.0:assertion:Assertion',
node_id=sresponse.assertion[0].id,
id_attr="")
assert valid
self.verify_assertion(sresponse.assertion)
def test_signed_response_2(self):
signed_resp = self.server.create_authn_response(
self.ava,
"id12", # in_response_to
"http://lingon.catalogix.se:8087/", # consumer_url
"urn:mace:example.com:saml:roland:sp", # sp_entity_id
name_id=self.name_id,
sign_response=True,
sign_assertion=True,
sign_alg=ds.SIG_RSA_SHA256,
digest_alg=ds.DIGEST_SHA512
)
sresponse = response_from_string(signed_resp)
assert ds.SIG_RSA_SHA256 in str(sresponse), "Not correctly signed!"
assert ds.DIGEST_SHA512 in str(sresponse), "Not correctly signed!"
valid = self.server.sec.verify_signature(signed_resp,
self.server.config.cert_file,
node_name='urn:oasis:names:tc:SAML:2.0:protocol:Response',
node_id=sresponse.id,
id_attr="")
assert valid
assert ds.SIG_RSA_SHA256 in str(sresponse.assertion[0]), "Not correctly signed!"
assert ds.DIGEST_SHA512 in str(sresponse.assertion[0]), "Not correctly signed!"
valid = self.server.sec.verify_signature(signed_resp,
self.server.config.cert_file,
node_name='urn:oasis:names:tc:SAML:2.0:assertion:Assertion',
node_id=sresponse.assertion[0].id,
id_attr="")
assert valid
self.verify_assertion(sresponse.assertion)
if __name__ == "__main__":
ts = TestSignedResponse()
ts.setup_class()
ts.test_signed_response()
ts.test_signed_response_1()
ts.test_signed_response_2()