Added support for signing/verifying messages when using the HTTP-Redirect binding.
This commit is contained in:
@@ -208,8 +208,8 @@ class Entity(HTTPBase):
|
||||
if not seid:
|
||||
seid = sid(self.seed)
|
||||
|
||||
return {"id":seid, "version":VERSION,
|
||||
"issue_instant":instant(), "issuer":self._issuer()}
|
||||
return {"id": seid, "version": VERSION,
|
||||
"issue_instant": instant(), "issuer": self._issuer()}
|
||||
|
||||
def response_args(self, message, bindings=None, descr_type=""):
|
||||
info = {"in_response_to": message.id}
|
||||
@@ -278,10 +278,9 @@ class Entity(HTTPBase):
|
||||
:param text: The SOAP message
|
||||
:return: A dictionary with two keys "body" and "header"
|
||||
"""
|
||||
return class_instances_from_soap_enveloped_saml_thingies(text,
|
||||
[paos,
|
||||
ecp,
|
||||
samlp])
|
||||
return class_instances_from_soap_enveloped_saml_thingies(text, [paos,
|
||||
ecp,
|
||||
samlp])
|
||||
|
||||
def unpack_soap_message(self, text):
|
||||
"""
|
||||
@@ -367,8 +366,8 @@ class Entity(HTTPBase):
|
||||
_issuer = self._issuer(issuer)
|
||||
|
||||
response = response_factory(issuer=_issuer,
|
||||
in_response_to = in_response_to,
|
||||
status = status)
|
||||
in_response_to=in_response_to,
|
||||
status=status)
|
||||
|
||||
if consumer_url:
|
||||
response.destination = consumer_url
|
||||
@@ -377,7 +376,7 @@ class Entity(HTTPBase):
|
||||
setattr(response, key, val)
|
||||
|
||||
if sign:
|
||||
self.sign(response,to_sign=to_sign)
|
||||
self.sign(response, to_sign=to_sign)
|
||||
elif to_sign:
|
||||
return signed_instance_factory(response, self.sec, to_sign)
|
||||
else:
|
||||
@@ -689,8 +688,8 @@ class Entity(HTTPBase):
|
||||
if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]:
|
||||
try:
|
||||
# expected return address
|
||||
kwargs["return_addr"] = self.config.endpoint(service,
|
||||
binding=binding)[0]
|
||||
kwargs["return_addr"] = self.config.endpoint(
|
||||
service, binding=binding)[0]
|
||||
except Exception:
|
||||
logger.info("Not supposed to handle this!")
|
||||
return None
|
||||
|
@@ -23,12 +23,15 @@ Bindings normally consists of three parts:
|
||||
- how to package the information
|
||||
- which protocol to use
|
||||
"""
|
||||
import hashlib
|
||||
import urlparse
|
||||
import saml2
|
||||
import base64
|
||||
import urllib
|
||||
from saml2.s_utils import deflate_and_base64_encode
|
||||
from saml2.s_utils import deflate_and_base64_encode, Unsupported
|
||||
import logging
|
||||
import M2Crypto
|
||||
from saml2.sigver import RSA_SHA1, rsa_load, x509_rsa_loads, pem_format
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +54,9 @@ FORM_SPEC = """<form method="post" action="%s">
|
||||
<input type="hidden" name="RelayState" value="%s" />
|
||||
</form>"""
|
||||
|
||||
def http_form_post_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
|
||||
def http_form_post_message(message, location, relay_state="",
|
||||
typ="SAMLRequest"):
|
||||
"""The HTTP POST binding defines a mechanism by which SAML protocol
|
||||
messages may be transmitted within the base64-encoded content of a
|
||||
HTML form control.
|
||||
@@ -93,7 +98,47 @@ def http_form_post_message(message, location, relay_state="", typ="SAMLRequest")
|
||||
# """
|
||||
# return {"headers": [("Content-type", "text/xml")], "data": message}
|
||||
|
||||
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
|
||||
class BadSignature(Exception):
|
||||
"""The signature is invalid."""
|
||||
pass
|
||||
|
||||
|
||||
def sha1_digest(msg):
|
||||
return hashlib.sha1(msg).digest()
|
||||
|
||||
|
||||
class Signer(object):
|
||||
"""Abstract base class for signing algorithms."""
|
||||
def sign(self, msg, key):
|
||||
"""Sign ``msg`` with ``key`` and return the signature."""
|
||||
raise NotImplementedError
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
"""Return True if ``sig`` is a valid signature for ``msg``."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RSASigner(Signer):
|
||||
def __init__(self, digest, algo):
|
||||
self.digest = digest
|
||||
self.algo = algo
|
||||
|
||||
def sign(self, msg, key):
|
||||
return key.sign(self.digest(msg), self.algo)
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
try:
|
||||
return key.verify(self.digest(msg), sig, self.algo)
|
||||
except M2Crypto.RSA.RSAError, e:
|
||||
raise BadSignature(e)
|
||||
|
||||
|
||||
REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"]
|
||||
RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"]
|
||||
|
||||
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest",
|
||||
sigalg=None, key=None):
|
||||
"""The HTTP Redirect binding defines a mechanism by which SAML protocol
|
||||
messages can be transmitted within URL parameters.
|
||||
Messages are encoded for use with this binding using a URL encoding
|
||||
@@ -104,13 +149,21 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
:param message: The message
|
||||
:param location: Where the message should be posted to
|
||||
:param relay_state: for preserving and conveying state information
|
||||
:param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart
|
||||
:param sigalg: The signature algorithm to use.
|
||||
:param key: Key to use for signing
|
||||
:return: A tuple containing header information and a HTML message.
|
||||
"""
|
||||
|
||||
if not isinstance(message, basestring):
|
||||
message = "%s" % (message,)
|
||||
|
||||
_order = None
|
||||
if typ in ["SAMLRequest", "SAMLResponse"]:
|
||||
if typ == "SAMLRequest":
|
||||
_order = REQ_ORDER
|
||||
else:
|
||||
_order = RESP_ORDER
|
||||
args = {typ: deflate_and_base64_encode(message)}
|
||||
elif typ == "SAMLart":
|
||||
args = {typ: message}
|
||||
@@ -120,16 +173,67 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
if relay_state:
|
||||
args["RelayState"] = relay_state
|
||||
|
||||
if sigalg:
|
||||
# sigalgs
|
||||
# http://www.w3.org/2000/09/xmldsig#dsa-sha1
|
||||
# http://www.w3.org/2000/09/xmldsig#rsa-sha1
|
||||
|
||||
args["SigAlg"] = sigalg
|
||||
|
||||
if sigalg == RSA_SHA1:
|
||||
signer = RSASigner(sha1_digest, "sha1")
|
||||
string = "&".join([urllib.urlencode({k: args[k]}) for k in _order])
|
||||
args["Signature"] = base64.b64encode(signer.sign(string, key))
|
||||
string = urllib.urlencode(args)
|
||||
else:
|
||||
raise Unsupported("Signing algorithm")
|
||||
else:
|
||||
string = urllib.urlencode(args)
|
||||
|
||||
glue_char = "&" if urlparse.urlparse(location).query else "?"
|
||||
login_url = glue_char.join([location, urllib.urlencode(args)])
|
||||
login_url = glue_char.join([location, string])
|
||||
headers = [('Location', login_url)]
|
||||
body = []
|
||||
|
||||
return {"headers":headers, "data":body}
|
||||
return {"headers": headers, "data": body}
|
||||
|
||||
|
||||
def verify_redirect_signature(info, cert):
|
||||
"""
|
||||
|
||||
:param info: A dictionary as produced by parse_qs, means all values are
|
||||
lists.
|
||||
:param cert: A certificate to use when verifying the signature
|
||||
:return: True, if signature verified
|
||||
"""
|
||||
|
||||
if info["SigAlg"][0] == RSA_SHA1:
|
||||
if "SAMLRequest" in info:
|
||||
_order = REQ_ORDER
|
||||
elif "SAMLResponse" in info:
|
||||
_order = RESP_ORDER
|
||||
else:
|
||||
raise Unsupported(
|
||||
"Verifying signature on something that should not be signed")
|
||||
signer = RSASigner(sha1_digest, "sha1")
|
||||
args = info.copy()
|
||||
del args["Signature"] # everything but the signature
|
||||
string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order])
|
||||
_key = x509_rsa_loads(pem_format(cert))
|
||||
_sign = base64.b64decode(info["Signature"][0])
|
||||
try:
|
||||
signer.verify(string, _sign, _key)
|
||||
return True
|
||||
except BadSignature:
|
||||
return False
|
||||
else:
|
||||
raise Unsupported("Signature algorithm: %s" % info["SigAlg"])
|
||||
|
||||
|
||||
DUMMY_NAMESPACE = "http://example.org/"
|
||||
PREFIX = '<?xml version="1.0" encoding="UTF-8"?>'
|
||||
|
||||
|
||||
def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
|
||||
""" Returns a soap envelope containing a SAML request
|
||||
as a text string.
|
||||
@@ -170,21 +274,24 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
|
||||
cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1]
|
||||
_str = _str.replace(cut1, "")
|
||||
first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],))
|
||||
last = _str.find(">", first+14)
|
||||
cut2 = _str[first:last+1]
|
||||
last = _str.find(">", first + 14)
|
||||
cut2 = _str[first:last + 1]
|
||||
return _str.replace(cut2, thingy)
|
||||
else:
|
||||
thingy.become_child_element_of(body)
|
||||
return ElementTree.tostring(envelope, encoding="UTF-8")
|
||||
|
||||
|
||||
def http_soap_message(message):
|
||||
return {"headers": [("Content-type", "application/soap+xml")],
|
||||
"data": make_soap_enveloped_saml_thingy(message)}
|
||||
|
||||
|
||||
def http_paos(message, extra=None):
|
||||
return {"headers":[("Content-type", "application/soap+xml")],
|
||||
return {"headers": [("Content-type", "application/soap+xml")],
|
||||
"data": make_soap_enveloped_saml_thingy(message, extra)}
|
||||
|
||||
|
||||
def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
||||
"""Parses a SOAP enveloped SAML thing and returns header parts and body
|
||||
|
||||
@@ -205,7 +312,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
||||
body = saml2.create_class_from_element_tree(body_class, sub)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Wrong body type (%s) in SOAP envelope" % sub.tag)
|
||||
"Wrong body type (%s) in SOAP envelope" % sub.tag)
|
||||
elif part.tag == '{%s}Header' % NAMESPACE:
|
||||
if not header_class:
|
||||
raise Exception("Header where I didn't expect one")
|
||||
@@ -226,13 +333,15 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
||||
PACKING = {
|
||||
saml2.BINDING_HTTP_REDIRECT: http_redirect_message,
|
||||
saml2.BINDING_HTTP_POST: http_form_post_message,
|
||||
}
|
||||
}
|
||||
|
||||
def packager( identifier ):
|
||||
|
||||
def packager(identifier):
|
||||
try:
|
||||
return PACKING[identifier]
|
||||
except KeyError:
|
||||
raise Exception("Unkown binding type: %s" % identifier)
|
||||
|
||||
|
||||
def factory(binding, message, location, relay_state="", typ="SAMLRequest"):
|
||||
return PACKING[binding](message, location, relay_state, typ)
|
@@ -9,7 +9,7 @@ import sys
|
||||
import hmac
|
||||
|
||||
# from python 2.5
|
||||
if sys.version_info >= (2,5):
|
||||
if sys.version_info >= (2, 5):
|
||||
import hashlib
|
||||
else: # before python 2.5
|
||||
import sha
|
||||
@@ -27,36 +27,51 @@ import zlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SamlException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RequestVersionTooLow(SamlException):
|
||||
pass
|
||||
|
||||
|
||||
class RequestVersionTooHigh(SamlException):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownPrincipal(SamlException):
|
||||
pass
|
||||
|
||||
class UnsupportedBinding(SamlException):
|
||||
|
||||
class Unsupported(SamlException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedBinding(Unsupported):
|
||||
pass
|
||||
|
||||
|
||||
class VersionMismatch(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Unknown(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OtherError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingValue(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PolicyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BadRequest(Exception):
|
||||
pass
|
||||
|
||||
@@ -73,28 +88,29 @@ EXCEPTION2STATUS = {
|
||||
Exception: samlp.STATUS_AUTHN_FAILED,
|
||||
}
|
||||
|
||||
GENERIC_DOMAINS = "aero", "asia", "biz", "cat", "com", "coop", \
|
||||
"edu", "gov", "info", "int", "jobs", "mil", "mobi", "museum", \
|
||||
"name", "net", "org", "pro", "tel", "travel"
|
||||
GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu",
|
||||
"gov", "info", "int", "jobs", "mil", "mobi", "museum",
|
||||
"name", "net", "org", "pro", "tel", "travel"]
|
||||
|
||||
def valid_email(emailaddress, domains = GENERIC_DOMAINS):
|
||||
|
||||
def valid_email(emailaddress, domains=GENERIC_DOMAINS):
|
||||
"""Checks for a syntactically valid email address."""
|
||||
|
||||
# Email address must be at least 6 characters in total.
|
||||
# Assuming noone may have addresses of the type a@com
|
||||
if len(emailaddress) < 6:
|
||||
return False # Address too short.
|
||||
return False # Address too short.
|
||||
|
||||
# Split up email address into parts.
|
||||
try:
|
||||
localpart, domainname = emailaddress.rsplit('@', 1)
|
||||
host, toplevel = domainname.rsplit('.', 1)
|
||||
except ValueError:
|
||||
return False # Address does not have enough parts.
|
||||
return False # Address does not have enough parts.
|
||||
|
||||
# Check for Country code or Generic Domain.
|
||||
if len(toplevel) != 2 and toplevel not in domains:
|
||||
return False # Not a domain name.
|
||||
return False # Not a domain name.
|
||||
|
||||
for i in '-_.%+.':
|
||||
localpart = localpart.replace(i, "")
|
||||
@@ -102,27 +118,30 @@ def valid_email(emailaddress, domains = GENERIC_DOMAINS):
|
||||
host = host.replace(i, "")
|
||||
|
||||
if localpart.isalnum() and host.isalnum():
|
||||
return True # Email address is fine.
|
||||
return True # Email address is fine.
|
||||
else:
|
||||
return False # Email address has funny characters.
|
||||
return False # Email address has funny characters.
|
||||
|
||||
def decode_base64_and_inflate( string ):
|
||||
|
||||
def decode_base64_and_inflate(string):
|
||||
""" base64 decodes and then inflates according to RFC1951
|
||||
|
||||
:param string: a deflated and encoded string
|
||||
:return: the string after decoding and inflating
|
||||
"""
|
||||
|
||||
return zlib.decompress( base64.b64decode( string ) , -15)
|
||||
return zlib.decompress(base64.b64decode(string), -15)
|
||||
|
||||
def deflate_and_base64_encode( string_val ):
|
||||
|
||||
def deflate_and_base64_encode(string_val):
|
||||
"""
|
||||
Deflates and the base64 encodes a string
|
||||
|
||||
:param string_val: The string to deflate and encode
|
||||
:return: The deflated and encoded string
|
||||
"""
|
||||
return base64.b64encode( zlib.compress( string_val )[2:-4] )
|
||||
return base64.b64encode(zlib.compress(string_val)[2:-4])
|
||||
|
||||
|
||||
def rndstr(size=16):
|
||||
"""
|
||||
@@ -134,9 +153,11 @@ def rndstr(size=16):
|
||||
_basech = string.ascii_letters + string.digits
|
||||
return "".join([random.choice(_basech) for _ in range(size)])
|
||||
|
||||
|
||||
def sid(seed=""):
|
||||
"""The hash of the server time + seed makes an unique SID for each session.
|
||||
128-bits long so it fulfills the SAML2 requirements which states 128-160 bits
|
||||
128-bits long so it fulfills the SAML2 requirements which states
|
||||
128-160 bits
|
||||
|
||||
:param seed: A seed string
|
||||
:return: The hex version of the digest, prefixed by 'id-' to make it
|
||||
@@ -146,7 +167,8 @@ def sid(seed=""):
|
||||
ident.update(repr(time.time()))
|
||||
if seed:
|
||||
ident.update(seed)
|
||||
return "id-"+ident.hexdigest()
|
||||
return "id-" + ident.hexdigest()
|
||||
|
||||
|
||||
def parse_attribute_map(filenames):
|
||||
"""
|
||||
@@ -168,6 +190,7 @@ def parse_attribute_map(filenames):
|
||||
|
||||
return forward, backward
|
||||
|
||||
|
||||
def identity_attribute(form, attribute, forward_map=None):
|
||||
if form == "friendly":
|
||||
if attribute.friendly_name:
|
||||
@@ -182,6 +205,7 @@ def identity_attribute(form, attribute, forward_map=None):
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def error_status_factory(info):
|
||||
if isinstance(info, Exception):
|
||||
try:
|
||||
@@ -194,39 +218,38 @@ def error_status_factory(info):
|
||||
status_code=samlp.StatusCode(
|
||||
value=samlp.STATUS_RESPONDER,
|
||||
status_code=samlp.StatusCode(
|
||||
value=exc_val)
|
||||
),
|
||||
)
|
||||
value=exc_val)))
|
||||
else:
|
||||
(errcode, text) = info
|
||||
status = samlp.Status(
|
||||
status_message=samlp.StatusMessage(text=text),
|
||||
status_code=samlp.StatusCode(
|
||||
value=samlp.STATUS_RESPONDER,
|
||||
status_code=samlp.StatusCode(value=errcode)
|
||||
),
|
||||
)
|
||||
status_code=samlp.StatusCode(value=errcode)))
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def success_status_factory():
|
||||
return samlp.Status(status_code=samlp.StatusCode(
|
||||
value=samlp.STATUS_SUCCESS))
|
||||
value=samlp.STATUS_SUCCESS))
|
||||
|
||||
|
||||
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
|
||||
return samlp.Status(
|
||||
status_message=samlp.StatusMessage(text=message),
|
||||
status_code=samlp.StatusCode(
|
||||
value=fro,
|
||||
status_code=samlp.StatusCode(value=code)))
|
||||
status_code=samlp.StatusCode(value=fro,
|
||||
status_code=samlp.StatusCode(value=code)))
|
||||
|
||||
|
||||
def assertion_factory(**kwargs):
|
||||
assertion = saml.Assertion(version=VERSION, id=sid(),
|
||||
issue_instant=instant())
|
||||
issue_instant=instant())
|
||||
for key, val in kwargs.items():
|
||||
setattr(assertion, key, val)
|
||||
return assertion
|
||||
|
||||
|
||||
def _attrval(val, typ=""):
|
||||
if isinstance(val, list) or isinstance(val, set):
|
||||
attrval = [saml.AttributeValue(text=v) for v in val]
|
||||
@@ -246,6 +269,7 @@ def _attrval(val, typ=""):
|
||||
# xmlns:xs="http://www.w3.org/2001/XMLSchema"
|
||||
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
|
||||
|
||||
def do_ava(val, typ=""):
|
||||
if isinstance(val, basestring):
|
||||
ava = saml.AttributeValue()
|
||||
@@ -253,7 +277,7 @@ def do_ava(val, typ=""):
|
||||
attrval = [ava]
|
||||
elif isinstance(val, list):
|
||||
attrval = [do_ava(v)[0] for v in val]
|
||||
elif val or val == False:
|
||||
elif val or val is False:
|
||||
ava = saml.AttributeValue()
|
||||
ava.set_text(val)
|
||||
attrval = [ava]
|
||||
@@ -268,6 +292,7 @@ def do_ava(val, typ=""):
|
||||
|
||||
return attrval
|
||||
|
||||
|
||||
def do_attribute(val, typ, key):
|
||||
attr = saml.Attribute()
|
||||
attrval = do_ava(val, typ)
|
||||
@@ -276,7 +301,7 @@ def do_attribute(val, typ, key):
|
||||
|
||||
if isinstance(key, basestring):
|
||||
attr.name = key
|
||||
elif isinstance(key, tuple): # 3-tuple or 2-tuple
|
||||
elif isinstance(key, tuple): # 3-tuple or 2-tuple
|
||||
try:
|
||||
(name, nformat, friendly) = key
|
||||
except ValueError:
|
||||
@@ -290,6 +315,7 @@ def do_attribute(val, typ, key):
|
||||
attr.friendly_name = friendly
|
||||
return attr
|
||||
|
||||
|
||||
def do_attributes(identity):
|
||||
attrs = []
|
||||
if not identity:
|
||||
@@ -308,6 +334,7 @@ def do_attributes(identity):
|
||||
attrs.append(attr)
|
||||
return attrs
|
||||
|
||||
|
||||
def do_attribute_statement(identity):
|
||||
"""
|
||||
:param identity: A dictionary with fiendly names as keys
|
||||
@@ -315,12 +342,14 @@ def do_attribute_statement(identity):
|
||||
"""
|
||||
return saml.AttributeStatement(attribute=do_attributes(identity))
|
||||
|
||||
|
||||
def factory(klass, **kwargs):
|
||||
instance = klass()
|
||||
for key, val in kwargs.items():
|
||||
setattr(instance, key, val)
|
||||
return instance
|
||||
|
||||
|
||||
def signature(secret, parts):
|
||||
"""Generates a signature.
|
||||
"""
|
||||
@@ -334,6 +363,7 @@ def signature(secret, parts):
|
||||
|
||||
return csum.hexdigest()
|
||||
|
||||
|
||||
def verify_signature(secret, parts):
|
||||
""" Checks that the signature is correct """
|
||||
if signature(secret, parts[:-1]) == parts[-1]:
|
||||
@@ -344,9 +374,10 @@ def verify_signature(secret, parts):
|
||||
|
||||
FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#"
|
||||
|
||||
|
||||
def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
|
||||
"""
|
||||
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value ) '#'
|
||||
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value) '#'
|
||||
Allowed attributes:
|
||||
TS the login time stamp
|
||||
RP the relying party entityID
|
||||
|
@@ -25,7 +25,9 @@ import logging
|
||||
import random
|
||||
import os
|
||||
import sys
|
||||
from time import mktime
|
||||
import M2Crypto
|
||||
from M2Crypto.X509 import load_cert_string
|
||||
from saml2.samlp import Response
|
||||
|
||||
import xmldsig as ds
|
||||
@@ -38,7 +40,7 @@ from saml2 import VERSION
|
||||
|
||||
from saml2.s_utils import sid
|
||||
|
||||
from saml2.time_util import instant
|
||||
from saml2.time_util import instant, utc_now, str_to_time
|
||||
|
||||
from tempfile import NamedTemporaryFile
|
||||
from subprocess import Popen, PIPE
|
||||
@@ -47,6 +49,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature")
|
||||
|
||||
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
|
||||
|
||||
|
||||
def signed(item):
|
||||
if SIG in item.c_children.keys() and item.signature:
|
||||
return True
|
||||
@@ -62,6 +67,7 @@ def signed(item):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_xmlsec_binary(paths=None):
|
||||
"""
|
||||
Tries to find the xmlsec1 binary.
|
||||
@@ -75,7 +81,7 @@ def get_xmlsec_binary(paths=None):
|
||||
bin_name = "xmlsec1"
|
||||
elif os.name == "nt":
|
||||
bin_name = "xmlsec1.exe"
|
||||
else: # Default !?
|
||||
else: # Default !?
|
||||
bin_name = "xmlsec1"
|
||||
|
||||
if paths:
|
||||
@@ -109,47 +115,36 @@ ENC_KEY_CLASS = "EncryptedKey"
|
||||
|
||||
_TEST_ = True
|
||||
|
||||
|
||||
class SignatureError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class XmlsecError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingKey(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DecryptError(Exception):
|
||||
pass
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
#def make_signed_instance(klass, spec, seccont, base64encode=False):
|
||||
# """ Will only return signed instance if the signature
|
||||
# preamble is present
|
||||
#
|
||||
# :param klass: The type of class the instance should be
|
||||
# :param spec: The specification of attributes and children of the class
|
||||
# :param seccont: The security context (instance of SecurityContext)
|
||||
# :param base64encode: Whether the attribute values should be base64 encoded
|
||||
# :return: A signed (or possibly unsigned) instance of the class
|
||||
# """
|
||||
# if "signature" in spec:
|
||||
# signed_xml = seccont.sign_statement_using_xmlsec("%s" % instance,
|
||||
# class_name(instance), instance.id)
|
||||
# return create_class_from_xml_string(instance.__class__, signed_xml)
|
||||
# else:
|
||||
# return make_instance(klass, spec, base64encode)
|
||||
|
||||
def xmlsec_version(execname):
|
||||
com_list = [execname,"--version"]
|
||||
com_list = [execname, "--version"]
|
||||
pof = Popen(com_list, stderr=PIPE, stdout=PIPE)
|
||||
try:
|
||||
return pof.stdout.read().split(" ")[1]
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
||||
base64encode=False, elements_to_sign=None):
|
||||
base64encode=False, elements_to_sign=None):
|
||||
"""
|
||||
Creates a class instance with a specified value, the specified
|
||||
class instance may be a value on a property in a defined class instance.
|
||||
@@ -169,14 +164,14 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
||||
|
||||
if isinstance(val, dict):
|
||||
cinst = _instance(klass, val, seccont, base64encode=base64encode,
|
||||
elements_to_sign=elements_to_sign)
|
||||
elements_to_sign=elements_to_sign)
|
||||
else:
|
||||
try:
|
||||
cinst = klass().set_text(val)
|
||||
except ValueError:
|
||||
if not part:
|
||||
cis = [_make_vals(sval, klass, seccont, klass_inst, prop,
|
||||
True, base64encode, elements_to_sign) for sval in val]
|
||||
True, base64encode, elements_to_sign) for sval in val]
|
||||
setattr(klass_inst, prop, cis)
|
||||
else:
|
||||
raise
|
||||
@@ -188,6 +183,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
||||
cis = [cinst]
|
||||
setattr(klass_inst, prop, cis)
|
||||
|
||||
|
||||
def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
||||
instance = klass()
|
||||
|
||||
@@ -208,30 +204,31 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
||||
#print "## %s, %s" % (prop, klassdef)
|
||||
if prop in ava:
|
||||
#print "### %s" % ava[prop]
|
||||
if isinstance(klassdef, list): # means there can be a list of values
|
||||
if isinstance(klassdef, list):
|
||||
# means there can be a list of values
|
||||
_make_vals(ava[prop], klassdef[0], seccont, instance, prop,
|
||||
base64encode=base64encode,
|
||||
elements_to_sign=elements_to_sign)
|
||||
base64encode=base64encode,
|
||||
elements_to_sign=elements_to_sign)
|
||||
else:
|
||||
cis = _make_vals(ava[prop], klassdef, seccont, instance, prop,
|
||||
True, base64encode, elements_to_sign)
|
||||
True, base64encode, elements_to_sign)
|
||||
setattr(instance, prop, cis)
|
||||
|
||||
if "extension_elements" in ava:
|
||||
for item in ava["extension_elements"]:
|
||||
instance.extension_elements.append(
|
||||
ExtensionElement(item["tag"]).loadd(item))
|
||||
ExtensionElement(item["tag"]).loadd(item))
|
||||
|
||||
if "extension_attributes" in ava:
|
||||
for key, val in ava["extension_attributes"].items():
|
||||
instance.extension_attributes[key] = val
|
||||
|
||||
|
||||
|
||||
if "signature" in ava:
|
||||
elements_to_sign.append((class_name(instance), instance.id))
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
||||
"""
|
||||
|
||||
@@ -243,8 +240,8 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
||||
if elements_to_sign:
|
||||
signed_xml = "%s" % instance
|
||||
for (node_name, nodeid) in elements_to_sign:
|
||||
signed_xml = seccont.sign_statement_using_xmlsec(signed_xml,
|
||||
klass_namn=node_name, nodeid=nodeid)
|
||||
signed_xml = seccont.sign_statement_using_xmlsec(
|
||||
signed_xml, klass_namn=node_name, nodeid=nodeid)
|
||||
|
||||
#print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
#print "%s" % signed_xml
|
||||
@@ -255,6 +252,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_id():
|
||||
""" Create a string of 40 random characters from the set [a-p],
|
||||
can be used as a unique identifier of objects.
|
||||
@@ -266,6 +264,7 @@ def create_id():
|
||||
ret += chr(random.randint(0, 15) + ord('a'))
|
||||
return ret
|
||||
|
||||
|
||||
def make_temp(string, suffix="", decode=True):
|
||||
""" xmlsec needs files in some cases where only strings exist, hence the
|
||||
need for this function. It creates a temporary file with the
|
||||
@@ -288,9 +287,34 @@ def make_temp(string, suffix="", decode=True):
|
||||
ntf.seek(0)
|
||||
return ntf, ntf.name
|
||||
|
||||
|
||||
def split_len(seq, length):
|
||||
return [seq[i:i+length] for i in range(0, len(seq), length)]
|
||||
|
||||
return [seq[i:i + length] for i in range(0, len(seq), length)]
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
M2_TIME_FORMAT = "%b %d %H:%M:%S %Y"
|
||||
|
||||
|
||||
def to_time(_time):
|
||||
assert _time.endswith(" GMT")
|
||||
_time = _time[:-4]
|
||||
return mktime(str_to_time(_time, M2_TIME_FORMAT))
|
||||
|
||||
|
||||
def active_cert(key):
|
||||
cert_str = pem_format(key)
|
||||
certificate = load_cert_string(cert_str)
|
||||
not_before = to_time(str(certificate.get_not_before()))
|
||||
not_after = to_time(str(certificate.get_not_after()))
|
||||
try:
|
||||
assert not_before < utc_now()
|
||||
assert not_after > utc_now()
|
||||
return True
|
||||
except AssertionError:
|
||||
return False
|
||||
|
||||
|
||||
def cert_from_key_info(key_info):
|
||||
""" Get all X509 certs from a KeyInfo instance. Care is taken to make sure
|
||||
that the certs are continues sequences of bytes.
|
||||
@@ -307,11 +331,15 @@ def cert_from_key_info(key_info):
|
||||
#print "X509Data",x509_data
|
||||
x509_certificate = x509_data.x509_certificate
|
||||
cert = x509_certificate.text.strip()
|
||||
cert = "\n".join(split_len("".join([
|
||||
s.strip() for s in cert.split()]),64))
|
||||
res.append(cert)
|
||||
cert = "\n".join(split_len("".join([s.strip() for s in
|
||||
cert.split()]), 64))
|
||||
if active_cert(cert):
|
||||
res.append(cert)
|
||||
else:
|
||||
logger.info("Inactive cert")
|
||||
return res
|
||||
|
||||
|
||||
def cert_from_key_info_dict(key_info):
|
||||
""" Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure
|
||||
that the certs are continues sequences of bytes.
|
||||
@@ -330,11 +358,15 @@ def cert_from_key_info_dict(key_info):
|
||||
for x509_data in key_info["x509_data"]:
|
||||
x509_certificate = x509_data["x509_certificate"]
|
||||
cert = x509_certificate["text"].strip()
|
||||
cert = "\n".join(split_len("".join([
|
||||
s.strip() for s in cert.split()]),64))
|
||||
res.append(cert)
|
||||
cert = "\n".join(split_len("".join([s.strip() for s in
|
||||
cert.split()]), 64))
|
||||
if active_cert(cert):
|
||||
res.append(cert)
|
||||
else:
|
||||
logger.info("Inactive cert")
|
||||
return res
|
||||
|
||||
|
||||
def cert_from_instance(instance):
|
||||
""" Find certificates that are part of an instance
|
||||
|
||||
@@ -350,25 +382,30 @@ def cert_from_instance(instance):
|
||||
from M2Crypto.__m2crypto import bn_to_mpi
|
||||
from M2Crypto.__m2crypto import hex_to_bn
|
||||
|
||||
|
||||
def intarr2long(arr):
|
||||
return long(''.join(["%02x" % byte for byte in arr]), 16)
|
||||
|
||||
|
||||
def dehexlify(bi):
|
||||
s = hexlify(bi)
|
||||
return [int(s[i]+s[i+1], 16) for i in range(0,len(s),2)]
|
||||
return [int(s[i] + s[i + 1], 16) for i in range(0, len(s), 2)]
|
||||
|
||||
|
||||
def long_to_mpi(num):
|
||||
"""Converts a python integer or long to OpenSSL MPInt used by M2Crypto.
|
||||
Borrowed from Snowball.Shared.Crypto"""
|
||||
h = hex(num)[2:] # strip leading 0x in string
|
||||
h = hex(num)[2:] # strip leading 0x in string
|
||||
if len(h) % 2 == 1:
|
||||
h = '0' + h # add leading 0 to get even number of hexdigits
|
||||
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
|
||||
h = '0' + h # add leading 0 to get even number of hexdigits
|
||||
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
|
||||
|
||||
|
||||
def base64_to_long(data):
|
||||
_d = base64.urlsafe_b64decode(data + '==')
|
||||
return intarr2long(dehexlify(_d))
|
||||
|
||||
|
||||
def key_from_key_value(key_info):
|
||||
res = []
|
||||
for value in key_info.key_value:
|
||||
@@ -376,10 +413,11 @@ def key_from_key_value(key_info):
|
||||
e = base64_to_long(value.rsa_key_value.exponent)
|
||||
m = base64_to_long(value.rsa_key_value.modulus)
|
||||
key = M2Crypto.RSA.new_pub_key((long_to_mpi(e),
|
||||
long_to_mpi(m)))
|
||||
long_to_mpi(m)))
|
||||
res.append(key)
|
||||
return res
|
||||
|
||||
|
||||
def key_from_key_value_dict(key_info):
|
||||
res = []
|
||||
if not "key_value" in key_info:
|
||||
@@ -396,10 +434,28 @@ def key_from_key_value_dict(key_info):
|
||||
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def rsa_load(filename):
|
||||
"""Read a PEM-encoded RSA key pair from a file."""
|
||||
return M2Crypto.RSA.load_key(filename, M2Crypto.util.no_passphrase_callback)
|
||||
|
||||
|
||||
def rsa_loads(key):
|
||||
"""Read a PEM-encoded RSA key pair from a string."""
|
||||
return M2Crypto.RSA.load_key_string(key,
|
||||
M2Crypto.util.no_passphrase_callback)
|
||||
|
||||
|
||||
def x509_rsa_loads(string):
|
||||
cert = M2Crypto.X509.load_cert_string(string)
|
||||
return cert.get_pubkey().get_rsa()
|
||||
|
||||
|
||||
def pem_format(key):
|
||||
return "\n".join(["-----BEGIN CERTIFICATE-----",
|
||||
key,"-----END CERTIFICATE-----"])
|
||||
key, "-----END CERTIFICATE-----"])
|
||||
|
||||
|
||||
def parse_xmlsec_output(output):
|
||||
""" Parse the output from xmlsec to try to find out if the
|
||||
command was successfull or not.
|
||||
@@ -416,12 +472,13 @@ def parse_xmlsec_output(output):
|
||||
|
||||
__DEBUG = 0
|
||||
|
||||
LOG_LINE = 60*"="+"\n%s\n"+60*"-"+"\n%s"+60*"="
|
||||
LOG_LINE_2 = 60*"="+"\n%s\n%s\n"+60*"-"+"\n%s"+60*"="
|
||||
LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||
LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||
|
||||
|
||||
def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
|
||||
node_name=NODE_NAME, debug=False, node_id=None,
|
||||
id_attr=""):
|
||||
node_name=NODE_NAME, debug=False, node_id=None,
|
||||
id_attr=""):
|
||||
""" Verifies the signature of a XML document.
|
||||
|
||||
:param enctext: The signed XML document
|
||||
@@ -481,6 +538,7 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_cert_from_file(cert_file, cert_type):
|
||||
""" Reads a certificate from a file. The assumption is that there is
|
||||
only one certificate in the file
|
||||
@@ -516,6 +574,7 @@ def read_cert_from_file(cert_file, cert_type):
|
||||
data = open(cert_file).read()
|
||||
return base64.b64encode(str(data))
|
||||
|
||||
|
||||
def security_context(conf, debug=None):
|
||||
""" Creates a security context based on the configuration
|
||||
|
||||
@@ -535,14 +594,15 @@ def security_context(conf, debug=None):
|
||||
_only_md = False
|
||||
|
||||
return SecurityContext(conf.xmlsec_binary, conf.key_file,
|
||||
cert_file=conf.cert_file, metadata=metadata,
|
||||
debug=debug, only_use_keys_in_metadata=_only_md)
|
||||
cert_file=conf.cert_file, metadata=metadata,
|
||||
debug=debug, only_use_keys_in_metadata=_only_md)
|
||||
|
||||
|
||||
class SecurityContext(object):
|
||||
def __init__(self, xmlsec_binary, key_file="", key_type= "pem",
|
||||
cert_file="", cert_type="pem", metadata=None,
|
||||
debug=False, template="", encrypt_key_type="des-192",
|
||||
only_use_keys_in_metadata=False):
|
||||
def __init__(self, xmlsec_binary, key_file="", key_type="pem",
|
||||
cert_file="", cert_type="pem", metadata=None,
|
||||
debug=False, template="", encrypt_key_type="des-192",
|
||||
only_use_keys_in_metadata=False):
|
||||
|
||||
self.xmlsec = xmlsec_binary
|
||||
|
||||
@@ -592,12 +652,9 @@ class SecurityContext(object):
|
||||
_, fil = make_temp("%s" % text, decode=False)
|
||||
ntf = NamedTemporaryFile()
|
||||
|
||||
com_list = [self.xmlsec, "--encrypt",
|
||||
"--pubkey-pem", recv_key,
|
||||
"--session-key", key_type,
|
||||
"--xml-data", fil,
|
||||
"--output", ntf.name,
|
||||
template]
|
||||
com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key,
|
||||
"--session-key", key_type, "--xml-data", fil,
|
||||
"--output", ntf.name, template]
|
||||
|
||||
logger.debug("Encryption command: %s" % " ".join(com_list))
|
||||
|
||||
@@ -625,11 +682,9 @@ class SecurityContext(object):
|
||||
_, fil = make_temp("%s" % enctext, decode=False)
|
||||
ntf = NamedTemporaryFile()
|
||||
|
||||
com_list = [self.xmlsec, "--decrypt",
|
||||
"--privkey-pem", self.key_file,
|
||||
"--output", ntf.name,
|
||||
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS,
|
||||
fil]
|
||||
com_list = [self.xmlsec, "--decrypt", "--privkey-pem",
|
||||
self.key_file, "--output", ntf.name,
|
||||
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, fil]
|
||||
|
||||
logger.debug("Decrypt command: %s" % " ".join(com_list))
|
||||
|
||||
@@ -646,9 +701,8 @@ class SecurityContext(object):
|
||||
ntf.seek(0)
|
||||
return ntf.read()
|
||||
|
||||
|
||||
def verify_signature(self, enctext, cert_file=None, cert_type="pem",
|
||||
node_name=NODE_NAME, node_id=None, id_attr=""):
|
||||
node_name=NODE_NAME, node_id=None, id_attr=""):
|
||||
""" Verifies the signature of a XML document.
|
||||
|
||||
:param enctext: The XML document as a string
|
||||
@@ -773,22 +827,22 @@ class SecurityContext(object):
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_authn_query(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "authn_query",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_logout_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "logout_request",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_logout_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "logout_response",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_attribute_query(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "attribute_query",
|
||||
must, origdoc)
|
||||
|
||||
@@ -799,31 +853,31 @@ class SecurityContext(object):
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_authz_decision_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"authz_decision_response", must,
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"name_id_mapping_request",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"name_id_mapping_response",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_artifact_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"artifact_request",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_artifact_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"artifact_response",
|
||||
must, origdoc)
|
||||
@@ -835,19 +889,19 @@ class SecurityContext(object):
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_manage_name_id_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"manage_name_id_response", must,
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_assertion_id_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"assertion_id_request", must,
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_assertion_id_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "assertion", must,
|
||||
origdoc)
|
||||
|
||||
@@ -882,18 +936,16 @@ class SecurityContext(object):
|
||||
|
||||
try:
|
||||
self._check_signature(decoded_xml, assertion,
|
||||
class_name(assertion), origdoc)
|
||||
class_name(assertion), origdoc)
|
||||
except Exception, exc:
|
||||
logger.error("correctly_signed_response: %s" % exc)
|
||||
raise
|
||||
|
||||
return response
|
||||
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
# SIGNATURE PART
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
def sign_statement_using_xmlsec(self, statement, klass_namn, key=None,
|
||||
key_file=None, nodeid=None, id_attr=""):
|
||||
"""Sign a SAML statement using xmlsec.
|
||||
@@ -917,7 +969,6 @@ class SecurityContext(object):
|
||||
|
||||
_, fil = make_temp("%s" % statement, decode=False)
|
||||
|
||||
|
||||
ntf = NamedTemporaryFile()
|
||||
|
||||
com_list = [self.xmlsec, "--sign",
|
||||
@@ -975,10 +1026,8 @@ class SecurityContext(object):
|
||||
:return: The signed statement
|
||||
"""
|
||||
|
||||
return self.sign_statement_using_xmlsec(statement,
|
||||
class_name(samlp.AttributeQuery()),
|
||||
key, key_file, nodeid,
|
||||
id_attr=id_attr)
|
||||
return self.sign_statement_using_xmlsec(statement, class_name(
|
||||
samlp.AttributeQuery()), key, key_file, nodeid, id_attr=id_attr)
|
||||
|
||||
def multiple_signatures(self, statement, to_sign, key=None, key_file=None):
|
||||
"""
|
||||
@@ -991,15 +1040,15 @@ class SecurityContext(object):
|
||||
:param key_file: A file that contains the key to be used
|
||||
:return: A possibly multiple signed statement
|
||||
"""
|
||||
for (item, id, id_attr) in to_sign:
|
||||
if not id:
|
||||
for (item, sid, id_attr) in to_sign:
|
||||
if not sid:
|
||||
if not item.id:
|
||||
id = item.id = sid()
|
||||
sid = item.id = sid()
|
||||
else:
|
||||
id = item.id
|
||||
sid = item.id
|
||||
|
||||
if not item.signature:
|
||||
item.signature = pre_signature_part(id, self.cert_file)
|
||||
item.signature = pre_signature_part(sid, self.cert_file)
|
||||
|
||||
statement = self.sign_statement_using_xmlsec(statement,
|
||||
class_name(item),
|
||||
@@ -1024,32 +1073,30 @@ def pre_signature_part(ident, public_key=None, identifier=None):
|
||||
:return: A preset signature part
|
||||
"""
|
||||
|
||||
signature_method = ds.SignatureMethod(algorithm = ds.SIG_RSA_SHA1)
|
||||
signature_method = ds.SignatureMethod(algorithm=ds.SIG_RSA_SHA1)
|
||||
canonicalization_method = ds.CanonicalizationMethod(
|
||||
algorithm = ds.ALG_EXC_C14N)
|
||||
trans0 = ds.Transform(algorithm = ds.TRANSFORM_ENVELOPED)
|
||||
trans1 = ds.Transform(algorithm = ds.ALG_EXC_C14N)
|
||||
transforms = ds.Transforms(transform = [trans0, trans1])
|
||||
digest_method = ds.DigestMethod(algorithm = ds.DIGEST_SHA1)
|
||||
algorithm=ds.ALG_EXC_C14N)
|
||||
trans0 = ds.Transform(algorithm=ds.TRANSFORM_ENVELOPED)
|
||||
trans1 = ds.Transform(algorithm=ds.ALG_EXC_C14N)
|
||||
transforms = ds.Transforms(transform=[trans0, trans1])
|
||||
digest_method = ds.DigestMethod(algorithm=ds.DIGEST_SHA1)
|
||||
|
||||
reference = ds.Reference(uri = "#%s" % ident,
|
||||
digest_value = ds.DigestValue(),
|
||||
transforms = transforms,
|
||||
digest_method = digest_method)
|
||||
reference = ds.Reference(uri="#%s" % ident, digest_value=ds.DigestValue(),
|
||||
transforms=transforms, digest_method=digest_method)
|
||||
|
||||
signed_info = ds.SignedInfo(signature_method = signature_method,
|
||||
canonicalization_method = canonicalization_method,
|
||||
reference = reference)
|
||||
signed_info = ds.SignedInfo(signature_method=signature_method,
|
||||
canonicalization_method=canonicalization_method,
|
||||
reference=reference)
|
||||
|
||||
signature = ds.Signature(signed_info=signed_info,
|
||||
signature_value=ds.SignatureValue())
|
||||
signature = ds.Signature(signed_info=signed_info,
|
||||
signature_value=ds.SignatureValue())
|
||||
|
||||
if identifier:
|
||||
signature.id = "Signature%d" % identifier
|
||||
|
||||
if public_key:
|
||||
x509_data = ds.X509Data(x509_certificate=[ds.X509Certificate(
|
||||
text=public_key)])
|
||||
x509_data = ds.X509Data(
|
||||
x509_certificate=[ds.X509Certificate(text=public_key)])
|
||||
key_info = ds.KeyInfo(x509_data=x509_data)
|
||||
signature.key_info = key_info
|
||||
|
||||
|
Reference in New Issue
Block a user