Added support for signing/verifying messages when using the HTTP-Redirect binding.

This commit is contained in:
Roland Hedberg
2013-02-21 12:36:05 +01:00
parent ed8b6953bb
commit cc71990164
4 changed files with 350 additions and 164 deletions

View File

@@ -208,8 +208,8 @@ class Entity(HTTPBase):
if not seid:
seid = sid(self.seed)
return {"id":seid, "version":VERSION,
"issue_instant":instant(), "issuer":self._issuer()}
return {"id": seid, "version": VERSION,
"issue_instant": instant(), "issuer": self._issuer()}
def response_args(self, message, bindings=None, descr_type=""):
info = {"in_response_to": message.id}
@@ -278,10 +278,9 @@ class Entity(HTTPBase):
:param text: The SOAP message
:return: A dictionary with two keys "body" and "header"
"""
return class_instances_from_soap_enveloped_saml_thingies(text,
[paos,
ecp,
samlp])
return class_instances_from_soap_enveloped_saml_thingies(text, [paos,
ecp,
samlp])
def unpack_soap_message(self, text):
"""
@@ -367,8 +366,8 @@ class Entity(HTTPBase):
_issuer = self._issuer(issuer)
response = response_factory(issuer=_issuer,
in_response_to = in_response_to,
status = status)
in_response_to=in_response_to,
status=status)
if consumer_url:
response.destination = consumer_url
@@ -377,7 +376,7 @@ class Entity(HTTPBase):
setattr(response, key, val)
if sign:
self.sign(response,to_sign=to_sign)
self.sign(response, to_sign=to_sign)
elif to_sign:
return signed_instance_factory(response, self.sec, to_sign)
else:
@@ -689,8 +688,8 @@ class Entity(HTTPBase):
if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]:
try:
# expected return address
kwargs["return_addr"] = self.config.endpoint(service,
binding=binding)[0]
kwargs["return_addr"] = self.config.endpoint(
service, binding=binding)[0]
except Exception:
logger.info("Not supposed to handle this!")
return None

View File

@@ -23,12 +23,15 @@ Bindings normally consists of three parts:
- how to package the information
- which protocol to use
"""
import hashlib
import urlparse
import saml2
import base64
import urllib
from saml2.s_utils import deflate_and_base64_encode
from saml2.s_utils import deflate_and_base64_encode, Unsupported
import logging
import M2Crypto
from saml2.sigver import RSA_SHA1, rsa_load, x509_rsa_loads, pem_format
logger = logging.getLogger(__name__)
@@ -51,7 +54,9 @@ FORM_SPEC = """<form method="post" action="%s">
<input type="hidden" name="RelayState" value="%s" />
</form>"""
def http_form_post_message(message, location, relay_state="", typ="SAMLRequest"):
def http_form_post_message(message, location, relay_state="",
typ="SAMLRequest"):
"""The HTTP POST binding defines a mechanism by which SAML protocol
messages may be transmitted within the base64-encoded content of a
HTML form control.
@@ -93,7 +98,47 @@ def http_form_post_message(message, location, relay_state="", typ="SAMLRequest")
# """
# return {"headers": [("Content-type", "text/xml")], "data": message}
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
class BadSignature(Exception):
"""The signature is invalid."""
pass
def sha1_digest(msg):
return hashlib.sha1(msg).digest()
class Signer(object):
"""Abstract base class for signing algorithms."""
def sign(self, msg, key):
"""Sign ``msg`` with ``key`` and return the signature."""
raise NotImplementedError
def verify(self, msg, sig, key):
"""Return True if ``sig`` is a valid signature for ``msg``."""
raise NotImplementedError
class RSASigner(Signer):
def __init__(self, digest, algo):
self.digest = digest
self.algo = algo
def sign(self, msg, key):
return key.sign(self.digest(msg), self.algo)
def verify(self, msg, sig, key):
try:
return key.verify(self.digest(msg), sig, self.algo)
except M2Crypto.RSA.RSAError, e:
raise BadSignature(e)
REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"]
RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"]
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest",
sigalg=None, key=None):
"""The HTTP Redirect binding defines a mechanism by which SAML protocol
messages can be transmitted within URL parameters.
Messages are encoded for use with this binding using a URL encoding
@@ -104,13 +149,21 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
:param message: The message
:param location: Where the message should be posted to
:param relay_state: for preserving and conveying state information
:param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart
:param sigalg: The signature algorithm to use.
:param key: Key to use for signing
:return: A tuple containing header information and a HTML message.
"""
if not isinstance(message, basestring):
message = "%s" % (message,)
_order = None
if typ in ["SAMLRequest", "SAMLResponse"]:
if typ == "SAMLRequest":
_order = REQ_ORDER
else:
_order = RESP_ORDER
args = {typ: deflate_and_base64_encode(message)}
elif typ == "SAMLart":
args = {typ: message}
@@ -120,16 +173,67 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
if relay_state:
args["RelayState"] = relay_state
if sigalg:
# sigalgs
# http://www.w3.org/2000/09/xmldsig#dsa-sha1
# http://www.w3.org/2000/09/xmldsig#rsa-sha1
args["SigAlg"] = sigalg
if sigalg == RSA_SHA1:
signer = RSASigner(sha1_digest, "sha1")
string = "&".join([urllib.urlencode({k: args[k]}) for k in _order])
args["Signature"] = base64.b64encode(signer.sign(string, key))
string = urllib.urlencode(args)
else:
raise Unsupported("Signing algorithm")
else:
string = urllib.urlencode(args)
glue_char = "&" if urlparse.urlparse(location).query else "?"
login_url = glue_char.join([location, urllib.urlencode(args)])
login_url = glue_char.join([location, string])
headers = [('Location', login_url)]
body = []
return {"headers":headers, "data":body}
return {"headers": headers, "data": body}
def verify_redirect_signature(info, cert):
"""
:param info: A dictionary as produced by parse_qs, means all values are
lists.
:param cert: A certificate to use when verifying the signature
:return: True, if signature verified
"""
if info["SigAlg"][0] == RSA_SHA1:
if "SAMLRequest" in info:
_order = REQ_ORDER
elif "SAMLResponse" in info:
_order = RESP_ORDER
else:
raise Unsupported(
"Verifying signature on something that should not be signed")
signer = RSASigner(sha1_digest, "sha1")
args = info.copy()
del args["Signature"] # everything but the signature
string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order])
_key = x509_rsa_loads(pem_format(cert))
_sign = base64.b64decode(info["Signature"][0])
try:
signer.verify(string, _sign, _key)
return True
except BadSignature:
return False
else:
raise Unsupported("Signature algorithm: %s" % info["SigAlg"])
DUMMY_NAMESPACE = "http://example.org/"
PREFIX = '<?xml version="1.0" encoding="UTF-8"?>'
def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
""" Returns a soap envelope containing a SAML request
as a text string.
@@ -170,21 +274,24 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1]
_str = _str.replace(cut1, "")
first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],))
last = _str.find(">", first+14)
cut2 = _str[first:last+1]
last = _str.find(">", first + 14)
cut2 = _str[first:last + 1]
return _str.replace(cut2, thingy)
else:
thingy.become_child_element_of(body)
return ElementTree.tostring(envelope, encoding="UTF-8")
def http_soap_message(message):
return {"headers": [("Content-type", "application/soap+xml")],
"data": make_soap_enveloped_saml_thingy(message)}
def http_paos(message, extra=None):
return {"headers":[("Content-type", "application/soap+xml")],
return {"headers": [("Content-type", "application/soap+xml")],
"data": make_soap_enveloped_saml_thingy(message, extra)}
def parse_soap_enveloped_saml(text, body_class, header_class=None):
"""Parses a SOAP enveloped SAML thing and returns header parts and body
@@ -205,7 +312,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
body = saml2.create_class_from_element_tree(body_class, sub)
except Exception:
raise Exception(
"Wrong body type (%s) in SOAP envelope" % sub.tag)
"Wrong body type (%s) in SOAP envelope" % sub.tag)
elif part.tag == '{%s}Header' % NAMESPACE:
if not header_class:
raise Exception("Header where I didn't expect one")
@@ -226,13 +333,15 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
PACKING = {
saml2.BINDING_HTTP_REDIRECT: http_redirect_message,
saml2.BINDING_HTTP_POST: http_form_post_message,
}
}
def packager( identifier ):
def packager(identifier):
try:
return PACKING[identifier]
except KeyError:
raise Exception("Unkown binding type: %s" % identifier)
def factory(binding, message, location, relay_state="", typ="SAMLRequest"):
return PACKING[binding](message, location, relay_state, typ)

View File

@@ -9,7 +9,7 @@ import sys
import hmac
# from python 2.5
if sys.version_info >= (2,5):
if sys.version_info >= (2, 5):
import hashlib
else: # before python 2.5
import sha
@@ -27,36 +27,51 @@ import zlib
logger = logging.getLogger(__name__)
class SamlException(Exception):
pass
class RequestVersionTooLow(SamlException):
pass
class RequestVersionTooHigh(SamlException):
pass
class UnknownPrincipal(SamlException):
pass
class UnsupportedBinding(SamlException):
class Unsupported(SamlException):
pass
class UnsupportedBinding(Unsupported):
pass
class VersionMismatch(Exception):
pass
class Unknown(Exception):
pass
class OtherError(Exception):
pass
class MissingValue(Exception):
pass
class PolicyError(Exception):
pass
class BadRequest(Exception):
pass
@@ -73,28 +88,29 @@ EXCEPTION2STATUS = {
Exception: samlp.STATUS_AUTHN_FAILED,
}
GENERIC_DOMAINS = "aero", "asia", "biz", "cat", "com", "coop", \
"edu", "gov", "info", "int", "jobs", "mil", "mobi", "museum", \
"name", "net", "org", "pro", "tel", "travel"
GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu",
"gov", "info", "int", "jobs", "mil", "mobi", "museum",
"name", "net", "org", "pro", "tel", "travel"]
def valid_email(emailaddress, domains = GENERIC_DOMAINS):
def valid_email(emailaddress, domains=GENERIC_DOMAINS):
"""Checks for a syntactically valid email address."""
# Email address must be at least 6 characters in total.
# Assuming noone may have addresses of the type a@com
if len(emailaddress) < 6:
return False # Address too short.
return False # Address too short.
# Split up email address into parts.
try:
localpart, domainname = emailaddress.rsplit('@', 1)
host, toplevel = domainname.rsplit('.', 1)
except ValueError:
return False # Address does not have enough parts.
return False # Address does not have enough parts.
# Check for Country code or Generic Domain.
if len(toplevel) != 2 and toplevel not in domains:
return False # Not a domain name.
return False # Not a domain name.
for i in '-_.%+.':
localpart = localpart.replace(i, "")
@@ -102,27 +118,30 @@ def valid_email(emailaddress, domains = GENERIC_DOMAINS):
host = host.replace(i, "")
if localpart.isalnum() and host.isalnum():
return True # Email address is fine.
return True # Email address is fine.
else:
return False # Email address has funny characters.
return False # Email address has funny characters.
def decode_base64_and_inflate( string ):
def decode_base64_and_inflate(string):
""" base64 decodes and then inflates according to RFC1951
:param string: a deflated and encoded string
:return: the string after decoding and inflating
"""
return zlib.decompress( base64.b64decode( string ) , -15)
return zlib.decompress(base64.b64decode(string), -15)
def deflate_and_base64_encode( string_val ):
def deflate_and_base64_encode(string_val):
"""
Deflates and the base64 encodes a string
:param string_val: The string to deflate and encode
:return: The deflated and encoded string
"""
return base64.b64encode( zlib.compress( string_val )[2:-4] )
return base64.b64encode(zlib.compress(string_val)[2:-4])
def rndstr(size=16):
"""
@@ -134,9 +153,11 @@ def rndstr(size=16):
_basech = string.ascii_letters + string.digits
return "".join([random.choice(_basech) for _ in range(size)])
def sid(seed=""):
"""The hash of the server time + seed makes an unique SID for each session.
128-bits long so it fulfills the SAML2 requirements which states 128-160 bits
128-bits long so it fulfills the SAML2 requirements which states
128-160 bits
:param seed: A seed string
:return: The hex version of the digest, prefixed by 'id-' to make it
@@ -146,7 +167,8 @@ def sid(seed=""):
ident.update(repr(time.time()))
if seed:
ident.update(seed)
return "id-"+ident.hexdigest()
return "id-" + ident.hexdigest()
def parse_attribute_map(filenames):
"""
@@ -168,6 +190,7 @@ def parse_attribute_map(filenames):
return forward, backward
def identity_attribute(form, attribute, forward_map=None):
if form == "friendly":
if attribute.friendly_name:
@@ -182,6 +205,7 @@ def identity_attribute(form, attribute, forward_map=None):
#----------------------------------------------------------------------------
def error_status_factory(info):
if isinstance(info, Exception):
try:
@@ -194,39 +218,38 @@ def error_status_factory(info):
status_code=samlp.StatusCode(
value=samlp.STATUS_RESPONDER,
status_code=samlp.StatusCode(
value=exc_val)
),
)
value=exc_val)))
else:
(errcode, text) = info
status = samlp.Status(
status_message=samlp.StatusMessage(text=text),
status_code=samlp.StatusCode(
value=samlp.STATUS_RESPONDER,
status_code=samlp.StatusCode(value=errcode)
),
)
status_code=samlp.StatusCode(value=errcode)))
return status
def success_status_factory():
return samlp.Status(status_code=samlp.StatusCode(
value=samlp.STATUS_SUCCESS))
value=samlp.STATUS_SUCCESS))
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
return samlp.Status(
status_message=samlp.StatusMessage(text=message),
status_code=samlp.StatusCode(
value=fro,
status_code=samlp.StatusCode(value=code)))
status_code=samlp.StatusCode(value=fro,
status_code=samlp.StatusCode(value=code)))
def assertion_factory(**kwargs):
assertion = saml.Assertion(version=VERSION, id=sid(),
issue_instant=instant())
issue_instant=instant())
for key, val in kwargs.items():
setattr(assertion, key, val)
return assertion
def _attrval(val, typ=""):
if isinstance(val, list) or isinstance(val, set):
attrval = [saml.AttributeValue(text=v) for v in val]
@@ -246,6 +269,7 @@ def _attrval(val, typ=""):
# xmlns:xs="http://www.w3.org/2001/XMLSchema"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
def do_ava(val, typ=""):
if isinstance(val, basestring):
ava = saml.AttributeValue()
@@ -253,7 +277,7 @@ def do_ava(val, typ=""):
attrval = [ava]
elif isinstance(val, list):
attrval = [do_ava(v)[0] for v in val]
elif val or val == False:
elif val or val is False:
ava = saml.AttributeValue()
ava.set_text(val)
attrval = [ava]
@@ -268,6 +292,7 @@ def do_ava(val, typ=""):
return attrval
def do_attribute(val, typ, key):
attr = saml.Attribute()
attrval = do_ava(val, typ)
@@ -276,7 +301,7 @@ def do_attribute(val, typ, key):
if isinstance(key, basestring):
attr.name = key
elif isinstance(key, tuple): # 3-tuple or 2-tuple
elif isinstance(key, tuple): # 3-tuple or 2-tuple
try:
(name, nformat, friendly) = key
except ValueError:
@@ -290,6 +315,7 @@ def do_attribute(val, typ, key):
attr.friendly_name = friendly
return attr
def do_attributes(identity):
attrs = []
if not identity:
@@ -308,6 +334,7 @@ def do_attributes(identity):
attrs.append(attr)
return attrs
def do_attribute_statement(identity):
"""
:param identity: A dictionary with fiendly names as keys
@@ -315,12 +342,14 @@ def do_attribute_statement(identity):
"""
return saml.AttributeStatement(attribute=do_attributes(identity))
def factory(klass, **kwargs):
instance = klass()
for key, val in kwargs.items():
setattr(instance, key, val)
return instance
def signature(secret, parts):
"""Generates a signature.
"""
@@ -334,6 +363,7 @@ def signature(secret, parts):
return csum.hexdigest()
def verify_signature(secret, parts):
""" Checks that the signature is correct """
if signature(secret, parts[:-1]) == parts[-1]:
@@ -344,9 +374,10 @@ def verify_signature(secret, parts):
FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#"
def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
"""
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value ) '#'
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value) '#'
Allowed attributes:
TS the login time stamp
RP the relying party entityID

View File

@@ -25,7 +25,9 @@ import logging
import random
import os
import sys
from time import mktime
import M2Crypto
from M2Crypto.X509 import load_cert_string
from saml2.samlp import Response
import xmldsig as ds
@@ -38,7 +40,7 @@ from saml2 import VERSION
from saml2.s_utils import sid
from saml2.time_util import instant
from saml2.time_util import instant, utc_now, str_to_time
from tempfile import NamedTemporaryFile
from subprocess import Popen, PIPE
@@ -47,6 +49,9 @@ logger = logging.getLogger(__name__)
SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature")
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
def signed(item):
if SIG in item.c_children.keys() and item.signature:
return True
@@ -62,6 +67,7 @@ def signed(item):
return False
def get_xmlsec_binary(paths=None):
"""
Tries to find the xmlsec1 binary.
@@ -75,7 +81,7 @@ def get_xmlsec_binary(paths=None):
bin_name = "xmlsec1"
elif os.name == "nt":
bin_name = "xmlsec1.exe"
else: # Default !?
else: # Default !?
bin_name = "xmlsec1"
if paths:
@@ -109,47 +115,36 @@ ENC_KEY_CLASS = "EncryptedKey"
_TEST_ = True
class SignatureError(Exception):
pass
class XmlsecError(Exception):
pass
class MissingKey(Exception):
pass
class DecryptError(Exception):
pass
# --------------------------------------------------------------------------
#def make_signed_instance(klass, spec, seccont, base64encode=False):
# """ Will only return signed instance if the signature
# preamble is present
#
# :param klass: The type of class the instance should be
# :param spec: The specification of attributes and children of the class
# :param seccont: The security context (instance of SecurityContext)
# :param base64encode: Whether the attribute values should be base64 encoded
# :return: A signed (or possibly unsigned) instance of the class
# """
# if "signature" in spec:
# signed_xml = seccont.sign_statement_using_xmlsec("%s" % instance,
# class_name(instance), instance.id)
# return create_class_from_xml_string(instance.__class__, signed_xml)
# else:
# return make_instance(klass, spec, base64encode)
def xmlsec_version(execname):
com_list = [execname,"--version"]
com_list = [execname, "--version"]
pof = Popen(com_list, stderr=PIPE, stdout=PIPE)
try:
return pof.stdout.read().split(" ")[1]
except Exception:
return ""
def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
base64encode=False, elements_to_sign=None):
base64encode=False, elements_to_sign=None):
"""
Creates a class instance with a specified value, the specified
class instance may be a value on a property in a defined class instance.
@@ -169,14 +164,14 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
if isinstance(val, dict):
cinst = _instance(klass, val, seccont, base64encode=base64encode,
elements_to_sign=elements_to_sign)
elements_to_sign=elements_to_sign)
else:
try:
cinst = klass().set_text(val)
except ValueError:
if not part:
cis = [_make_vals(sval, klass, seccont, klass_inst, prop,
True, base64encode, elements_to_sign) for sval in val]
True, base64encode, elements_to_sign) for sval in val]
setattr(klass_inst, prop, cis)
else:
raise
@@ -188,6 +183,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
cis = [cinst]
setattr(klass_inst, prop, cis)
def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
instance = klass()
@@ -208,30 +204,31 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
#print "## %s, %s" % (prop, klassdef)
if prop in ava:
#print "### %s" % ava[prop]
if isinstance(klassdef, list): # means there can be a list of values
if isinstance(klassdef, list):
# means there can be a list of values
_make_vals(ava[prop], klassdef[0], seccont, instance, prop,
base64encode=base64encode,
elements_to_sign=elements_to_sign)
base64encode=base64encode,
elements_to_sign=elements_to_sign)
else:
cis = _make_vals(ava[prop], klassdef, seccont, instance, prop,
True, base64encode, elements_to_sign)
True, base64encode, elements_to_sign)
setattr(instance, prop, cis)
if "extension_elements" in ava:
for item in ava["extension_elements"]:
instance.extension_elements.append(
ExtensionElement(item["tag"]).loadd(item))
ExtensionElement(item["tag"]).loadd(item))
if "extension_attributes" in ava:
for key, val in ava["extension_attributes"].items():
instance.extension_attributes[key] = val
if "signature" in ava:
elements_to_sign.append((class_name(instance), instance.id))
return instance
def signed_instance_factory(instance, seccont, elements_to_sign=None):
"""
@@ -243,8 +240,8 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
if elements_to_sign:
signed_xml = "%s" % instance
for (node_name, nodeid) in elements_to_sign:
signed_xml = seccont.sign_statement_using_xmlsec(signed_xml,
klass_namn=node_name, nodeid=nodeid)
signed_xml = seccont.sign_statement_using_xmlsec(
signed_xml, klass_namn=node_name, nodeid=nodeid)
#print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
#print "%s" % signed_xml
@@ -255,6 +252,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
# --------------------------------------------------------------------------
def create_id():
""" Create a string of 40 random characters from the set [a-p],
can be used as a unique identifier of objects.
@@ -266,6 +264,7 @@ def create_id():
ret += chr(random.randint(0, 15) + ord('a'))
return ret
def make_temp(string, suffix="", decode=True):
""" xmlsec needs files in some cases where only strings exist, hence the
need for this function. It creates a temporary file with the
@@ -288,9 +287,34 @@ def make_temp(string, suffix="", decode=True):
ntf.seek(0)
return ntf, ntf.name
def split_len(seq, length):
return [seq[i:i+length] for i in range(0, len(seq), length)]
return [seq[i:i + length] for i in range(0, len(seq), length)]
# --------------------------------------------------------------------------
M2_TIME_FORMAT = "%b %d %H:%M:%S %Y"
def to_time(_time):
assert _time.endswith(" GMT")
_time = _time[:-4]
return mktime(str_to_time(_time, M2_TIME_FORMAT))
def active_cert(key):
cert_str = pem_format(key)
certificate = load_cert_string(cert_str)
not_before = to_time(str(certificate.get_not_before()))
not_after = to_time(str(certificate.get_not_after()))
try:
assert not_before < utc_now()
assert not_after > utc_now()
return True
except AssertionError:
return False
def cert_from_key_info(key_info):
""" Get all X509 certs from a KeyInfo instance. Care is taken to make sure
that the certs are continues sequences of bytes.
@@ -307,11 +331,15 @@ def cert_from_key_info(key_info):
#print "X509Data",x509_data
x509_certificate = x509_data.x509_certificate
cert = x509_certificate.text.strip()
cert = "\n".join(split_len("".join([
s.strip() for s in cert.split()]),64))
res.append(cert)
cert = "\n".join(split_len("".join([s.strip() for s in
cert.split()]), 64))
if active_cert(cert):
res.append(cert)
else:
logger.info("Inactive cert")
return res
def cert_from_key_info_dict(key_info):
""" Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure
that the certs are continues sequences of bytes.
@@ -330,11 +358,15 @@ def cert_from_key_info_dict(key_info):
for x509_data in key_info["x509_data"]:
x509_certificate = x509_data["x509_certificate"]
cert = x509_certificate["text"].strip()
cert = "\n".join(split_len("".join([
s.strip() for s in cert.split()]),64))
res.append(cert)
cert = "\n".join(split_len("".join([s.strip() for s in
cert.split()]), 64))
if active_cert(cert):
res.append(cert)
else:
logger.info("Inactive cert")
return res
def cert_from_instance(instance):
""" Find certificates that are part of an instance
@@ -350,25 +382,30 @@ def cert_from_instance(instance):
from M2Crypto.__m2crypto import bn_to_mpi
from M2Crypto.__m2crypto import hex_to_bn
def intarr2long(arr):
return long(''.join(["%02x" % byte for byte in arr]), 16)
def dehexlify(bi):
s = hexlify(bi)
return [int(s[i]+s[i+1], 16) for i in range(0,len(s),2)]
return [int(s[i] + s[i + 1], 16) for i in range(0, len(s), 2)]
def long_to_mpi(num):
"""Converts a python integer or long to OpenSSL MPInt used by M2Crypto.
Borrowed from Snowball.Shared.Crypto"""
h = hex(num)[2:] # strip leading 0x in string
h = hex(num)[2:] # strip leading 0x in string
if len(h) % 2 == 1:
h = '0' + h # add leading 0 to get even number of hexdigits
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
h = '0' + h # add leading 0 to get even number of hexdigits
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
def base64_to_long(data):
_d = base64.urlsafe_b64decode(data + '==')
return intarr2long(dehexlify(_d))
def key_from_key_value(key_info):
res = []
for value in key_info.key_value:
@@ -376,10 +413,11 @@ def key_from_key_value(key_info):
e = base64_to_long(value.rsa_key_value.exponent)
m = base64_to_long(value.rsa_key_value.modulus)
key = M2Crypto.RSA.new_pub_key((long_to_mpi(e),
long_to_mpi(m)))
long_to_mpi(m)))
res.append(key)
return res
def key_from_key_value_dict(key_info):
res = []
if not "key_value" in key_info:
@@ -396,10 +434,28 @@ def key_from_key_value_dict(key_info):
# =============================================================================
def rsa_load(filename):
"""Read a PEM-encoded RSA key pair from a file."""
return M2Crypto.RSA.load_key(filename, M2Crypto.util.no_passphrase_callback)
def rsa_loads(key):
"""Read a PEM-encoded RSA key pair from a string."""
return M2Crypto.RSA.load_key_string(key,
M2Crypto.util.no_passphrase_callback)
def x509_rsa_loads(string):
cert = M2Crypto.X509.load_cert_string(string)
return cert.get_pubkey().get_rsa()
def pem_format(key):
return "\n".join(["-----BEGIN CERTIFICATE-----",
key,"-----END CERTIFICATE-----"])
key, "-----END CERTIFICATE-----"])
def parse_xmlsec_output(output):
""" Parse the output from xmlsec to try to find out if the
command was successfull or not.
@@ -416,12 +472,13 @@ def parse_xmlsec_output(output):
__DEBUG = 0
LOG_LINE = 60*"="+"\n%s\n"+60*"-"+"\n%s"+60*"="
LOG_LINE_2 = 60*"="+"\n%s\n%s\n"+60*"-"+"\n%s"+60*"="
LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
node_name=NODE_NAME, debug=False, node_id=None,
id_attr=""):
node_name=NODE_NAME, debug=False, node_id=None,
id_attr=""):
""" Verifies the signature of a XML document.
:param enctext: The signed XML document
@@ -481,6 +538,7 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
# ---------------------------------------------------------------------------
def read_cert_from_file(cert_file, cert_type):
""" Reads a certificate from a file. The assumption is that there is
only one certificate in the file
@@ -516,6 +574,7 @@ def read_cert_from_file(cert_file, cert_type):
data = open(cert_file).read()
return base64.b64encode(str(data))
def security_context(conf, debug=None):
""" Creates a security context based on the configuration
@@ -535,14 +594,15 @@ def security_context(conf, debug=None):
_only_md = False
return SecurityContext(conf.xmlsec_binary, conf.key_file,
cert_file=conf.cert_file, metadata=metadata,
debug=debug, only_use_keys_in_metadata=_only_md)
cert_file=conf.cert_file, metadata=metadata,
debug=debug, only_use_keys_in_metadata=_only_md)
class SecurityContext(object):
def __init__(self, xmlsec_binary, key_file="", key_type= "pem",
cert_file="", cert_type="pem", metadata=None,
debug=False, template="", encrypt_key_type="des-192",
only_use_keys_in_metadata=False):
def __init__(self, xmlsec_binary, key_file="", key_type="pem",
cert_file="", cert_type="pem", metadata=None,
debug=False, template="", encrypt_key_type="des-192",
only_use_keys_in_metadata=False):
self.xmlsec = xmlsec_binary
@@ -592,12 +652,9 @@ class SecurityContext(object):
_, fil = make_temp("%s" % text, decode=False)
ntf = NamedTemporaryFile()
com_list = [self.xmlsec, "--encrypt",
"--pubkey-pem", recv_key,
"--session-key", key_type,
"--xml-data", fil,
"--output", ntf.name,
template]
com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key,
"--session-key", key_type, "--xml-data", fil,
"--output", ntf.name, template]
logger.debug("Encryption command: %s" % " ".join(com_list))
@@ -625,11 +682,9 @@ class SecurityContext(object):
_, fil = make_temp("%s" % enctext, decode=False)
ntf = NamedTemporaryFile()
com_list = [self.xmlsec, "--decrypt",
"--privkey-pem", self.key_file,
"--output", ntf.name,
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS,
fil]
com_list = [self.xmlsec, "--decrypt", "--privkey-pem",
self.key_file, "--output", ntf.name,
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, fil]
logger.debug("Decrypt command: %s" % " ".join(com_list))
@@ -646,9 +701,8 @@ class SecurityContext(object):
ntf.seek(0)
return ntf.read()
def verify_signature(self, enctext, cert_file=None, cert_type="pem",
node_name=NODE_NAME, node_id=None, id_attr=""):
node_name=NODE_NAME, node_id=None, id_attr=""):
""" Verifies the signature of a XML document.
:param enctext: The XML document as a string
@@ -773,22 +827,22 @@ class SecurityContext(object):
must, origdoc)
def correctly_signed_authn_query(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml, "authn_query",
must, origdoc)
def correctly_signed_logout_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml, "logout_request",
must, origdoc)
def correctly_signed_logout_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml, "logout_response",
must, origdoc)
def correctly_signed_attribute_query(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml, "attribute_query",
must, origdoc)
@@ -799,31 +853,31 @@ class SecurityContext(object):
origdoc)
def correctly_signed_authz_decision_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml,
"authz_decision_response", must,
origdoc)
def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml,
"name_id_mapping_request",
must, origdoc)
def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml,
"name_id_mapping_response",
must, origdoc)
def correctly_signed_artifact_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml,
"artifact_request",
must, origdoc)
def correctly_signed_artifact_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml,
"artifact_response",
must, origdoc)
@@ -835,19 +889,19 @@ class SecurityContext(object):
must, origdoc)
def correctly_signed_manage_name_id_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml,
"manage_name_id_response", must,
origdoc)
def correctly_signed_assertion_id_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml,
"assertion_id_request", must,
origdoc)
def correctly_signed_assertion_id_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None):
return self.correctly_signed_message(decoded_xml, "assertion", must,
origdoc)
@@ -882,18 +936,16 @@ class SecurityContext(object):
try:
self._check_signature(decoded_xml, assertion,
class_name(assertion), origdoc)
class_name(assertion), origdoc)
except Exception, exc:
logger.error("correctly_signed_response: %s" % exc)
raise
return response
#--------------------------------------------------------------------------
# SIGNATURE PART
#--------------------------------------------------------------------------
def sign_statement_using_xmlsec(self, statement, klass_namn, key=None,
key_file=None, nodeid=None, id_attr=""):
"""Sign a SAML statement using xmlsec.
@@ -917,7 +969,6 @@ class SecurityContext(object):
_, fil = make_temp("%s" % statement, decode=False)
ntf = NamedTemporaryFile()
com_list = [self.xmlsec, "--sign",
@@ -975,10 +1026,8 @@ class SecurityContext(object):
:return: The signed statement
"""
return self.sign_statement_using_xmlsec(statement,
class_name(samlp.AttributeQuery()),
key, key_file, nodeid,
id_attr=id_attr)
return self.sign_statement_using_xmlsec(statement, class_name(
samlp.AttributeQuery()), key, key_file, nodeid, id_attr=id_attr)
def multiple_signatures(self, statement, to_sign, key=None, key_file=None):
"""
@@ -991,15 +1040,15 @@ class SecurityContext(object):
:param key_file: A file that contains the key to be used
:return: A possibly multiple signed statement
"""
for (item, id, id_attr) in to_sign:
if not id:
for (item, sid, id_attr) in to_sign:
if not sid:
if not item.id:
id = item.id = sid()
sid = item.id = sid()
else:
id = item.id
sid = item.id
if not item.signature:
item.signature = pre_signature_part(id, self.cert_file)
item.signature = pre_signature_part(sid, self.cert_file)
statement = self.sign_statement_using_xmlsec(statement,
class_name(item),
@@ -1024,32 +1073,30 @@ def pre_signature_part(ident, public_key=None, identifier=None):
:return: A preset signature part
"""
signature_method = ds.SignatureMethod(algorithm = ds.SIG_RSA_SHA1)
signature_method = ds.SignatureMethod(algorithm=ds.SIG_RSA_SHA1)
canonicalization_method = ds.CanonicalizationMethod(
algorithm = ds.ALG_EXC_C14N)
trans0 = ds.Transform(algorithm = ds.TRANSFORM_ENVELOPED)
trans1 = ds.Transform(algorithm = ds.ALG_EXC_C14N)
transforms = ds.Transforms(transform = [trans0, trans1])
digest_method = ds.DigestMethod(algorithm = ds.DIGEST_SHA1)
algorithm=ds.ALG_EXC_C14N)
trans0 = ds.Transform(algorithm=ds.TRANSFORM_ENVELOPED)
trans1 = ds.Transform(algorithm=ds.ALG_EXC_C14N)
transforms = ds.Transforms(transform=[trans0, trans1])
digest_method = ds.DigestMethod(algorithm=ds.DIGEST_SHA1)
reference = ds.Reference(uri = "#%s" % ident,
digest_value = ds.DigestValue(),
transforms = transforms,
digest_method = digest_method)
reference = ds.Reference(uri="#%s" % ident, digest_value=ds.DigestValue(),
transforms=transforms, digest_method=digest_method)
signed_info = ds.SignedInfo(signature_method = signature_method,
canonicalization_method = canonicalization_method,
reference = reference)
signed_info = ds.SignedInfo(signature_method=signature_method,
canonicalization_method=canonicalization_method,
reference=reference)
signature = ds.Signature(signed_info=signed_info,
signature_value=ds.SignatureValue())
signature = ds.Signature(signed_info=signed_info,
signature_value=ds.SignatureValue())
if identifier:
signature.id = "Signature%d" % identifier
if public_key:
x509_data = ds.X509Data(x509_certificate=[ds.X509Certificate(
text=public_key)])
x509_data = ds.X509Data(
x509_certificate=[ds.X509Certificate(text=public_key)])
key_info = ds.KeyInfo(x509_data=x509_data)
signature.key_info = key_info