Merge remote-tracking branch 'upstream/master'

# Conflicts:
#	src/saml2/entity.py

digest algorithm added to the same functions as sign alg.
This commit is contained in:
Hans Hörberg
2015-11-19 10:36:56 +01:00
parent bc93176fa6
commit 361b29f464
7 changed files with 29 additions and 29 deletions

View File

@@ -160,7 +160,7 @@ class Saml2Client(Base):
# find out which IdPs/AAs I should notify
entity_ids = self.users.issuers_of_info(name_id)
return self.do_logout(name_id, entity_ids, reason, expire, sign, sign_alg=sign_alg)
return self.do_logout(name_id, entity_ids, reason, expire, sign, sign_alg=sign_alg, digest_alg=digest_alg)
def do_logout(self, name_id, entity_ids, reason, expire, sign=None,
expected_binding=None, sign_alg=None, digest_alg=None, **kwargs):
@@ -232,7 +232,7 @@ class Saml2Client(Base):
key = kwargs.get("key", self.signkey)
srequest = str(request)
else:
srequest = self.sign(request, sign_alg=sign_alg)
srequest = self.sign(request, sign_alg=sign_alg, digest_alg=digest_alg)
else:
srequest = str(request)
@@ -316,7 +316,7 @@ class Saml2Client(Base):
return self.do_logout(decode(status["name_id"]),
status["entity_ids"],
status["reason"], status["not_on_or_after"],
status["sign"], sign_alg=sign_alg)
status["sign"], sign_alg=sign_alg, digest_alg=digest_alg)
def _use_soap(self, destination, query_type, **kwargs):
_create_func = getattr(self, "create_%s" % query_type)

View File

@@ -339,11 +339,11 @@ class Base(Entity):
return self._message(AuthnRequest, destination, message_id,
consent, extensions, sign, sign_prepare,
protocol_binding=binding,
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, **args)
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg, **args)
return self._message(AuthnRequest, destination, message_id, consent,
extensions, sign, sign_prepare,
protocol_binding=binding,
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, **args)
scoping=scoping, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg, **args)
def create_attribute_query(self, destination, name_id=None,
attribute=None, message_id=0, consent=None,
@@ -404,7 +404,7 @@ class Base(Entity):
return self._message(AttributeQuery, destination, message_id, consent,
extensions, sign, sign_prepare, subject=subject,
attribute=attribute, nsprefix=nsprefix, sign_alg=sign_alg)
attribute=attribute, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
# MUST use SOAP for
# AssertionIDRequest, SubjectQuery,
@@ -430,7 +430,7 @@ class Base(Entity):
return self._message(AuthzDecisionQuery, destination, message_id,
consent, extensions, sign, action=action,
evidence=evidence, resource=resource,
subject=subject, sign_alg=sign_alg, **kwargs)
subject=subject, sign_alg=sign_alg, digest_alg=digest_alg, **kwargs)
def create_authz_decision_query_using_assertion(self, destination,
assertion, action=None,
@@ -499,7 +499,7 @@ class Base(Entity):
extensions, sign, subject=subject,
session_index=session_index,
requested_authn_context=authn_context,
nsprefix=nsprefix, sign_alg=sign_alg)
nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
def create_name_id_mapping_request(self, name_id_policy,
name_id=None, base_id=None,
@@ -528,17 +528,17 @@ class Base(Entity):
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy, name_id=name_id,
nsprefix=nsprefix, sign_alg=sign_alg)
nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
elif base_id:
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy, base_id=base_id,
nsprefix=nsprefix, sign_alg=sign_alg)
nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
else:
return self._message(NameIDMappingRequest, destination, message_id,
consent, extensions, sign,
name_id_policy=name_id_policy,
encrypted_id=encrypted_id, nsprefix=nsprefix, sign_alg=sign_alg)
encrypted_id=encrypted_id, nsprefix=nsprefix, sign_alg=sign_alg, digest_alg=digest_alg)
# ======== response handling ===========

View File

@@ -775,7 +775,7 @@ def entities_descriptor(eds, valid_for, name, ident, sign, secc, sign_alg=None,
raise SAMLError("If you want to do signing you should define " +
"where your public key are")
entities.signature = pre_signature_part(ident, secc.my_cert, 1, sign_alg=sign_alg)
entities.signature = pre_signature_part(ident, secc.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
entities.id = ident
xmldoc = secc.sign_statement("%s" % entities, class_name(entities))
entities = md.entities_descriptor_from_string(xmldoc)
@@ -797,7 +797,7 @@ def sign_entity_descriptor(edesc, ident, secc, sign_alg=None, digest_alg=None):
if not ident:
ident = sid()
edesc.signature = pre_signature_part(ident, secc.my_cert, 1, sign_alg=sign_alg)
edesc.signature = pre_signature_part(ident, secc.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
edesc.id = ident
xmldoc = secc.sign_statement("%s" % edesc, class_name(edesc))
edesc = md.entity_descriptor_from_string(xmldoc)

View File

@@ -399,7 +399,7 @@ class Server(Entity):
if not encrypt_assertion:
if sign_assertion:
assertion.signature = pre_signature_part(assertion.id, self.sec.my_cert, 1,
sign_alg=sign_alg)
sign_alg=sign_alg, digest_alg=digest_alg)
to_sign.append((class_name(assertion), assertion.id))
#if not encrypted_advice_attributes:
@@ -429,7 +429,7 @@ class Server(Entity):
encrypt_assertion_self_contained=encrypt_assertion_self_contained,
encrypted_advice_attributes=encrypted_advice_attributes,
sign_assertion=sign_assertion,
pefim=pefim, sign_alg=sign_alg,
pefim=pefim, sign_alg=sign_alg, digest_alg=digest_alg,
**args)
# ------------------------------------------------------------------------
@@ -489,14 +489,14 @@ class Server(Entity):
if sign_assertion:
assertion.signature = pre_signature_part(assertion.id,
self.sec.my_cert, 1, sign_alg=sign_alg)
self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
# Just the assertion or the response and the assertion ?
to_sign = [(class_name(assertion), assertion.id)]
args["assertion"] = assertion
return self._response(in_response_to, destination, status, issuer,
sign_response, to_sign, sign_alg=sign_alg, **args)
sign_response, to_sign, sign_alg=sign_alg, digest_alg=digest_alg, **args)
# ------------------------------------------------------------------------
@@ -648,7 +648,7 @@ class Server(Entity):
encrypt_cert_advice=encrypt_cert_advice,
encrypt_cert_assertion=encrypt_cert_assertion,
pefim=pefim,
sign_alg=sign_alg)
sign_alg=sign_alg, digest_alg=digest_alg)
return self._authn_response(in_response_to, # in_response_to
destination, # consumer_url
sp_entity_id, # sp_entity_id
@@ -666,7 +666,7 @@ class Server(Entity):
encrypt_cert_advice=encrypt_cert_advice,
encrypt_cert_assertion=encrypt_cert_assertion,
pefim=pefim,
sign_alg=sign_alg)
sign_alg=sign_alg, digest_alg=digest_alg)
except MissingValue as exc:
return self.create_error_response(in_response_to, destination,
@@ -703,7 +703,7 @@ class Server(Entity):
if to_sign:
if assertion.signature is None:
assertion.signature = pre_signature_part(assertion.id,
self.sec.my_cert, 1, sign_alg=sign_alg)
self.sec.my_cert, 1, sign_alg=sign_alg, digest_alg=digest_alg)
return signed_instance_factory(assertion, self.sec, to_sign)
else:
@@ -735,7 +735,7 @@ class Server(Entity):
in_response_to=in_response_to, **ms_args)
if sign_response:
return self.sign(_resp, sign_alg=sign_alg)
return self.sign(_resp, sign_alg=sign_alg, digest_alg=digest_alg)
else:
logger.info("Message: %s", _resp)
return _resp
@@ -764,7 +764,7 @@ class Server(Entity):
args = {}
return self._response(in_response_to, "", status, issuer,
sign_response, to_sign=[], sign_alg=sign_alg, **args)
sign_response, to_sign=[], sign_alg=sign_alg, digest_alg=digest_alg, **args)
# ---------

View File

@@ -1779,7 +1779,7 @@ class SecurityContext(object):
sid = item.id
if not item.signature:
item.signature = pre_signature_part(sid, self.cert_file, sign_alg=sign_alg)
item.signature = pre_signature_part(sid, self.cert_file, sign_alg=sign_alg, digest_alg=digest_alg)
statement = self.sign_statement(statement, class_name(item),
key=key, key_file=key_file,
@@ -1922,7 +1922,7 @@ def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None,
issue_instant=instant())
if sign:
response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg)
response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg, digest_alg=digest_alg)
if encrypt:
pass

View File

@@ -61,8 +61,8 @@ TRANSFORM_XPATH = 'http://www.w3.org/TR/1999/REC-xpath-19991116'
TRANSFORM_ENVELOPED = 'http://www.w3.org/2000/09/xmldsig#enveloped-signature'
class DefaultSignature:
class _DefaultSignature:
class DefaultSignature(object):
class _DefaultSignature(object):
def __init__(self, sign_alg=None, digest_alg=None):
if sign_alg is None:
self.sign_alg = sig_default

View File

@@ -142,12 +142,12 @@ class TestSignedResponse():
sign_response=True,
sign_assertion=True,
sign_alg=ds.SIG_RSA_SHA256,
digest_alg=ds.DIGEST_SHA512
digest_alg=ds.DIGEST_SHA256
)
sresponse = response_from_string(signed_resp)
assert ds.SIG_RSA_SHA256 in str(sresponse), "Not correctly signed!"
assert ds.DIGEST_SHA512 in str(sresponse), "Not correctly signed!"
assert ds.DIGEST_SHA256 in str(sresponse), "Not correctly signed!"
valid = self.server.sec.verify_signature(signed_resp,
self.server.config.cert_file,
node_name='urn:oasis:names:tc:SAML:2.0:protocol:Response',
@@ -155,7 +155,7 @@ class TestSignedResponse():
id_attr="")
assert valid
assert ds.SIG_RSA_SHA256 in str(sresponse.assertion[0]), "Not correctly signed!"
assert ds.DIGEST_SHA512 in str(sresponse.assertion[0]), "Not correctly signed!"
assert ds.DIGEST_SHA256 in str(sresponse.assertion[0]), "Not correctly signed!"
valid = self.server.sec.verify_signature(signed_resp,
self.server.config.cert_file,
node_name='urn:oasis:names:tc:SAML:2.0:assertion:Assertion',