From cc71990164832ae37dd90f780124eb147997093b Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 21 Feb 2013 12:36:05 +0100 Subject: [PATCH] Added support for signing/verifying messages when using the HTTP-Redirect binding. --- src/saml2/entity.py | 21 ++-- src/saml2/pack.py | 131 +++++++++++++++++++-- src/saml2/s_utils.py | 93 ++++++++++----- src/saml2/sigver.py | 269 +++++++++++++++++++++++++------------------ 4 files changed, 350 insertions(+), 164 deletions(-) diff --git a/src/saml2/entity.py b/src/saml2/entity.py index de241c6..205f6d7 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -208,8 +208,8 @@ class Entity(HTTPBase): if not seid: seid = sid(self.seed) - return {"id":seid, "version":VERSION, - "issue_instant":instant(), "issuer":self._issuer()} + return {"id": seid, "version": VERSION, + "issue_instant": instant(), "issuer": self._issuer()} def response_args(self, message, bindings=None, descr_type=""): info = {"in_response_to": message.id} @@ -278,10 +278,9 @@ class Entity(HTTPBase): :param text: The SOAP message :return: A dictionary with two keys "body" and "header" """ - return class_instances_from_soap_enveloped_saml_thingies(text, - [paos, - ecp, - samlp]) + return class_instances_from_soap_enveloped_saml_thingies(text, [paos, + ecp, + samlp]) def unpack_soap_message(self, text): """ @@ -367,8 +366,8 @@ class Entity(HTTPBase): _issuer = self._issuer(issuer) response = response_factory(issuer=_issuer, - in_response_to = in_response_to, - status = status) + in_response_to=in_response_to, + status=status) if consumer_url: response.destination = consumer_url @@ -377,7 +376,7 @@ class Entity(HTTPBase): setattr(response, key, val) if sign: - self.sign(response,to_sign=to_sign) + self.sign(response, to_sign=to_sign) elif to_sign: return signed_instance_factory(response, self.sec, to_sign) else: @@ -689,8 +688,8 @@ class Entity(HTTPBase): if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]: try: # expected return address - kwargs["return_addr"] = self.config.endpoint(service, - binding=binding)[0] + kwargs["return_addr"] = self.config.endpoint( + service, binding=binding)[0] except Exception: logger.info("Not supposed to handle this!") return None diff --git a/src/saml2/pack.py b/src/saml2/pack.py index a985e7a..4cf5996 100644 --- a/src/saml2/pack.py +++ b/src/saml2/pack.py @@ -23,12 +23,15 @@ Bindings normally consists of three parts: - how to package the information - which protocol to use """ +import hashlib import urlparse import saml2 import base64 import urllib -from saml2.s_utils import deflate_and_base64_encode +from saml2.s_utils import deflate_and_base64_encode, Unsupported import logging +import M2Crypto +from saml2.sigver import RSA_SHA1, rsa_load, x509_rsa_loads, pem_format logger = logging.getLogger(__name__) @@ -51,7 +54,9 @@ FORM_SPEC = """
""" -def http_form_post_message(message, location, relay_state="", typ="SAMLRequest"): + +def http_form_post_message(message, location, relay_state="", + typ="SAMLRequest"): """The HTTP POST binding defines a mechanism by which SAML protocol messages may be transmitted within the base64-encoded content of a HTML form control. @@ -93,7 +98,47 @@ def http_form_post_message(message, location, relay_state="", typ="SAMLRequest") # """ # return {"headers": [("Content-type", "text/xml")], "data": message} -def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"): + +class BadSignature(Exception): + """The signature is invalid.""" + pass + + +def sha1_digest(msg): + return hashlib.sha1(msg).digest() + + +class Signer(object): + """Abstract base class for signing algorithms.""" + def sign(self, msg, key): + """Sign ``msg`` with ``key`` and return the signature.""" + raise NotImplementedError + + def verify(self, msg, sig, key): + """Return True if ``sig`` is a valid signature for ``msg``.""" + raise NotImplementedError + + +class RSASigner(Signer): + def __init__(self, digest, algo): + self.digest = digest + self.algo = algo + + def sign(self, msg, key): + return key.sign(self.digest(msg), self.algo) + + def verify(self, msg, sig, key): + try: + return key.verify(self.digest(msg), sig, self.algo) + except M2Crypto.RSA.RSAError, e: + raise BadSignature(e) + + +REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"] +RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"] + +def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", + sigalg=None, key=None): """The HTTP Redirect binding defines a mechanism by which SAML protocol messages can be transmitted within URL parameters. Messages are encoded for use with this binding using a URL encoding @@ -104,13 +149,21 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"): :param message: The message :param location: Where the message should be posted to :param relay_state: for preserving and conveying state information + :param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart + :param sigalg: The signature algorithm to use. + :param key: Key to use for signing :return: A tuple containing header information and a HTML message. """ if not isinstance(message, basestring): message = "%s" % (message,) + _order = None if typ in ["SAMLRequest", "SAMLResponse"]: + if typ == "SAMLRequest": + _order = REQ_ORDER + else: + _order = RESP_ORDER args = {typ: deflate_and_base64_encode(message)} elif typ == "SAMLart": args = {typ: message} @@ -120,16 +173,67 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"): if relay_state: args["RelayState"] = relay_state + if sigalg: + # sigalgs + # http://www.w3.org/2000/09/xmldsig#dsa-sha1 + # http://www.w3.org/2000/09/xmldsig#rsa-sha1 + + args["SigAlg"] = sigalg + + if sigalg == RSA_SHA1: + signer = RSASigner(sha1_digest, "sha1") + string = "&".join([urllib.urlencode({k: args[k]}) for k in _order]) + args["Signature"] = base64.b64encode(signer.sign(string, key)) + string = urllib.urlencode(args) + else: + raise Unsupported("Signing algorithm") + else: + string = urllib.urlencode(args) + glue_char = "&" if urlparse.urlparse(location).query else "?" - login_url = glue_char.join([location, urllib.urlencode(args)]) + login_url = glue_char.join([location, string]) headers = [('Location', login_url)] body = [] - return {"headers":headers, "data":body} + return {"headers": headers, "data": body} + + +def verify_redirect_signature(info, cert): + """ + + :param info: A dictionary as produced by parse_qs, means all values are + lists. + :param cert: A certificate to use when verifying the signature + :return: True, if signature verified + """ + + if info["SigAlg"][0] == RSA_SHA1: + if "SAMLRequest" in info: + _order = REQ_ORDER + elif "SAMLResponse" in info: + _order = RESP_ORDER + else: + raise Unsupported( + "Verifying signature on something that should not be signed") + signer = RSASigner(sha1_digest, "sha1") + args = info.copy() + del args["Signature"] # everything but the signature + string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order]) + _key = x509_rsa_loads(pem_format(cert)) + _sign = base64.b64decode(info["Signature"][0]) + try: + signer.verify(string, _sign, _key) + return True + except BadSignature: + return False + else: + raise Unsupported("Signature algorithm: %s" % info["SigAlg"]) + DUMMY_NAMESPACE = "http://example.org/" PREFIX = '' + def make_soap_enveloped_saml_thingy(thingy, header_parts=None): """ Returns a soap envelope containing a SAML request as a text string. @@ -170,21 +274,24 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None): cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1] _str = _str.replace(cut1, "") first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],)) - last = _str.find(">", first+14) - cut2 = _str[first:last+1] + last = _str.find(">", first + 14) + cut2 = _str[first:last + 1] return _str.replace(cut2, thingy) else: thingy.become_child_element_of(body) return ElementTree.tostring(envelope, encoding="UTF-8") + def http_soap_message(message): return {"headers": [("Content-type", "application/soap+xml")], "data": make_soap_enveloped_saml_thingy(message)} + def http_paos(message, extra=None): - return {"headers":[("Content-type", "application/soap+xml")], + return {"headers": [("Content-type", "application/soap+xml")], "data": make_soap_enveloped_saml_thingy(message, extra)} + def parse_soap_enveloped_saml(text, body_class, header_class=None): """Parses a SOAP enveloped SAML thing and returns header parts and body @@ -205,7 +312,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): body = saml2.create_class_from_element_tree(body_class, sub) except Exception: raise Exception( - "Wrong body type (%s) in SOAP envelope" % sub.tag) + "Wrong body type (%s) in SOAP envelope" % sub.tag) elif part.tag == '{%s}Header' % NAMESPACE: if not header_class: raise Exception("Header where I didn't expect one") @@ -226,13 +333,15 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): PACKING = { saml2.BINDING_HTTP_REDIRECT: http_redirect_message, saml2.BINDING_HTTP_POST: http_form_post_message, - } +} -def packager( identifier ): + +def packager(identifier): try: return PACKING[identifier] except KeyError: raise Exception("Unkown binding type: %s" % identifier) + def factory(binding, message, location, relay_state="", typ="SAMLRequest"): return PACKING[binding](message, location, relay_state, typ) \ No newline at end of file diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py index cb5e1d5..1b28dc1 100644 --- a/src/saml2/s_utils.py +++ b/src/saml2/s_utils.py @@ -9,7 +9,7 @@ import sys import hmac # from python 2.5 -if sys.version_info >= (2,5): +if sys.version_info >= (2, 5): import hashlib else: # before python 2.5 import sha @@ -27,36 +27,51 @@ import zlib logger = logging.getLogger(__name__) + class SamlException(Exception): pass + class RequestVersionTooLow(SamlException): pass + class RequestVersionTooHigh(SamlException): pass + class UnknownPrincipal(SamlException): pass -class UnsupportedBinding(SamlException): + +class Unsupported(SamlException): pass + +class UnsupportedBinding(Unsupported): + pass + + class VersionMismatch(Exception): pass + class Unknown(Exception): pass + class OtherError(Exception): pass + class MissingValue(Exception): pass + class PolicyError(Exception): pass + class BadRequest(Exception): pass @@ -73,28 +88,29 @@ EXCEPTION2STATUS = { Exception: samlp.STATUS_AUTHN_FAILED, } -GENERIC_DOMAINS = "aero", "asia", "biz", "cat", "com", "coop", \ - "edu", "gov", "info", "int", "jobs", "mil", "mobi", "museum", \ - "name", "net", "org", "pro", "tel", "travel" +GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu", + "gov", "info", "int", "jobs", "mil", "mobi", "museum", + "name", "net", "org", "pro", "tel", "travel"] -def valid_email(emailaddress, domains = GENERIC_DOMAINS): + +def valid_email(emailaddress, domains=GENERIC_DOMAINS): """Checks for a syntactically valid email address.""" # Email address must be at least 6 characters in total. # Assuming noone may have addresses of the type a@com if len(emailaddress) < 6: - return False # Address too short. + return False # Address too short. # Split up email address into parts. try: localpart, domainname = emailaddress.rsplit('@', 1) host, toplevel = domainname.rsplit('.', 1) except ValueError: - return False # Address does not have enough parts. + return False # Address does not have enough parts. # Check for Country code or Generic Domain. if len(toplevel) != 2 and toplevel not in domains: - return False # Not a domain name. + return False # Not a domain name. for i in '-_.%+.': localpart = localpart.replace(i, "") @@ -102,27 +118,30 @@ def valid_email(emailaddress, domains = GENERIC_DOMAINS): host = host.replace(i, "") if localpart.isalnum() and host.isalnum(): - return True # Email address is fine. + return True # Email address is fine. else: - return False # Email address has funny characters. + return False # Email address has funny characters. -def decode_base64_and_inflate( string ): + +def decode_base64_and_inflate(string): """ base64 decodes and then inflates according to RFC1951 :param string: a deflated and encoded string :return: the string after decoding and inflating """ - return zlib.decompress( base64.b64decode( string ) , -15) + return zlib.decompress(base64.b64decode(string), -15) -def deflate_and_base64_encode( string_val ): + +def deflate_and_base64_encode(string_val): """ Deflates and the base64 encodes a string :param string_val: The string to deflate and encode :return: The deflated and encoded string """ - return base64.b64encode( zlib.compress( string_val )[2:-4] ) + return base64.b64encode(zlib.compress(string_val)[2:-4]) + def rndstr(size=16): """ @@ -134,9 +153,11 @@ def rndstr(size=16): _basech = string.ascii_letters + string.digits return "".join([random.choice(_basech) for _ in range(size)]) + def sid(seed=""): """The hash of the server time + seed makes an unique SID for each session. - 128-bits long so it fulfills the SAML2 requirements which states 128-160 bits + 128-bits long so it fulfills the SAML2 requirements which states + 128-160 bits :param seed: A seed string :return: The hex version of the digest, prefixed by 'id-' to make it @@ -146,7 +167,8 @@ def sid(seed=""): ident.update(repr(time.time())) if seed: ident.update(seed) - return "id-"+ident.hexdigest() + return "id-" + ident.hexdigest() + def parse_attribute_map(filenames): """ @@ -168,6 +190,7 @@ def parse_attribute_map(filenames): return forward, backward + def identity_attribute(form, attribute, forward_map=None): if form == "friendly": if attribute.friendly_name: @@ -182,6 +205,7 @@ def identity_attribute(form, attribute, forward_map=None): #---------------------------------------------------------------------------- + def error_status_factory(info): if isinstance(info, Exception): try: @@ -194,39 +218,38 @@ def error_status_factory(info): status_code=samlp.StatusCode( value=samlp.STATUS_RESPONDER, status_code=samlp.StatusCode( - value=exc_val) - ), - ) + value=exc_val))) else: (errcode, text) = info status = samlp.Status( status_message=samlp.StatusMessage(text=text), status_code=samlp.StatusCode( value=samlp.STATUS_RESPONDER, - status_code=samlp.StatusCode(value=errcode) - ), - ) + status_code=samlp.StatusCode(value=errcode))) return status + def success_status_factory(): return samlp.Status(status_code=samlp.StatusCode( - value=samlp.STATUS_SUCCESS)) + value=samlp.STATUS_SUCCESS)) + def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER): return samlp.Status( status_message=samlp.StatusMessage(text=message), - status_code=samlp.StatusCode( - value=fro, - status_code=samlp.StatusCode(value=code))) + status_code=samlp.StatusCode(value=fro, + status_code=samlp.StatusCode(value=code))) + def assertion_factory(**kwargs): assertion = saml.Assertion(version=VERSION, id=sid(), - issue_instant=instant()) + issue_instant=instant()) for key, val in kwargs.items(): setattr(assertion, key, val) return assertion + def _attrval(val, typ=""): if isinstance(val, list) or isinstance(val, set): attrval = [saml.AttributeValue(text=v) for v in val] @@ -246,6 +269,7 @@ def _attrval(val, typ=""): # xmlns:xs="http://www.w3.org/2001/XMLSchema" # xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + def do_ava(val, typ=""): if isinstance(val, basestring): ava = saml.AttributeValue() @@ -253,7 +277,7 @@ def do_ava(val, typ=""): attrval = [ava] elif isinstance(val, list): attrval = [do_ava(v)[0] for v in val] - elif val or val == False: + elif val or val is False: ava = saml.AttributeValue() ava.set_text(val) attrval = [ava] @@ -268,6 +292,7 @@ def do_ava(val, typ=""): return attrval + def do_attribute(val, typ, key): attr = saml.Attribute() attrval = do_ava(val, typ) @@ -276,7 +301,7 @@ def do_attribute(val, typ, key): if isinstance(key, basestring): attr.name = key - elif isinstance(key, tuple): # 3-tuple or 2-tuple + elif isinstance(key, tuple): # 3-tuple or 2-tuple try: (name, nformat, friendly) = key except ValueError: @@ -290,6 +315,7 @@ def do_attribute(val, typ, key): attr.friendly_name = friendly return attr + def do_attributes(identity): attrs = [] if not identity: @@ -308,6 +334,7 @@ def do_attributes(identity): attrs.append(attr) return attrs + def do_attribute_statement(identity): """ :param identity: A dictionary with fiendly names as keys @@ -315,12 +342,14 @@ def do_attribute_statement(identity): """ return saml.AttributeStatement(attribute=do_attributes(identity)) + def factory(klass, **kwargs): instance = klass() for key, val in kwargs.items(): setattr(instance, key, val) return instance + def signature(secret, parts): """Generates a signature. """ @@ -334,6 +363,7 @@ def signature(secret, parts): return csum.hexdigest() + def verify_signature(secret, parts): """ Checks that the signature is correct """ if signature(secret, parts[:-1]) == parts[-1]: @@ -344,9 +374,10 @@ def verify_signature(secret, parts): FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#" + def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion): """ - 'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value ) '#' + 'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value) '#' Allowed attributes: TS the login time stamp RP the relying party entityID diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index dd647cf..f4e58fe 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -25,7 +25,9 @@ import logging import random import os import sys +from time import mktime import M2Crypto +from M2Crypto.X509 import load_cert_string from saml2.samlp import Response import xmldsig as ds @@ -38,7 +40,7 @@ from saml2 import VERSION from saml2.s_utils import sid -from saml2.time_util import instant +from saml2.time_util import instant, utc_now, str_to_time from tempfile import NamedTemporaryFile from subprocess import Popen, PIPE @@ -47,6 +49,9 @@ logger = logging.getLogger(__name__) SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature") +RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1" + + def signed(item): if SIG in item.c_children.keys() and item.signature: return True @@ -62,6 +67,7 @@ def signed(item): return False + def get_xmlsec_binary(paths=None): """ Tries to find the xmlsec1 binary. @@ -75,7 +81,7 @@ def get_xmlsec_binary(paths=None): bin_name = "xmlsec1" elif os.name == "nt": bin_name = "xmlsec1.exe" - else: # Default !? + else: # Default !? bin_name = "xmlsec1" if paths: @@ -109,47 +115,36 @@ ENC_KEY_CLASS = "EncryptedKey" _TEST_ = True + class SignatureError(Exception): pass + class XmlsecError(Exception): pass + class MissingKey(Exception): pass + class DecryptError(Exception): pass # -------------------------------------------------------------------------- -#def make_signed_instance(klass, spec, seccont, base64encode=False): -# """ Will only return signed instance if the signature -# preamble is present -# -# :param klass: The type of class the instance should be -# :param spec: The specification of attributes and children of the class -# :param seccont: The security context (instance of SecurityContext) -# :param base64encode: Whether the attribute values should be base64 encoded -# :return: A signed (or possibly unsigned) instance of the class -# """ -# if "signature" in spec: -# signed_xml = seccont.sign_statement_using_xmlsec("%s" % instance, -# class_name(instance), instance.id) -# return create_class_from_xml_string(instance.__class__, signed_xml) -# else: -# return make_instance(klass, spec, base64encode) def xmlsec_version(execname): - com_list = [execname,"--version"] + com_list = [execname, "--version"] pof = Popen(com_list, stderr=PIPE, stdout=PIPE) try: return pof.stdout.read().split(" ")[1] except Exception: return "" - + + def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, - base64encode=False, elements_to_sign=None): + base64encode=False, elements_to_sign=None): """ Creates a class instance with a specified value, the specified class instance may be a value on a property in a defined class instance. @@ -169,14 +164,14 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, if isinstance(val, dict): cinst = _instance(klass, val, seccont, base64encode=base64encode, - elements_to_sign=elements_to_sign) + elements_to_sign=elements_to_sign) else: try: cinst = klass().set_text(val) except ValueError: if not part: cis = [_make_vals(sval, klass, seccont, klass_inst, prop, - True, base64encode, elements_to_sign) for sval in val] + True, base64encode, elements_to_sign) for sval in val] setattr(klass_inst, prop, cis) else: raise @@ -188,6 +183,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, cis = [cinst] setattr(klass_inst, prop, cis) + def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): instance = klass() @@ -208,30 +204,31 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): #print "## %s, %s" % (prop, klassdef) if prop in ava: #print "### %s" % ava[prop] - if isinstance(klassdef, list): # means there can be a list of values + if isinstance(klassdef, list): + # means there can be a list of values _make_vals(ava[prop], klassdef[0], seccont, instance, prop, - base64encode=base64encode, - elements_to_sign=elements_to_sign) + base64encode=base64encode, + elements_to_sign=elements_to_sign) else: cis = _make_vals(ava[prop], klassdef, seccont, instance, prop, - True, base64encode, elements_to_sign) + True, base64encode, elements_to_sign) setattr(instance, prop, cis) if "extension_elements" in ava: for item in ava["extension_elements"]: instance.extension_elements.append( - ExtensionElement(item["tag"]).loadd(item)) + ExtensionElement(item["tag"]).loadd(item)) if "extension_attributes" in ava: for key, val in ava["extension_attributes"].items(): instance.extension_attributes[key] = val - - + if "signature" in ava: elements_to_sign.append((class_name(instance), instance.id)) return instance + def signed_instance_factory(instance, seccont, elements_to_sign=None): """ @@ -243,8 +240,8 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): if elements_to_sign: signed_xml = "%s" % instance for (node_name, nodeid) in elements_to_sign: - signed_xml = seccont.sign_statement_using_xmlsec(signed_xml, - klass_namn=node_name, nodeid=nodeid) + signed_xml = seccont.sign_statement_using_xmlsec( + signed_xml, klass_namn=node_name, nodeid=nodeid) #print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" #print "%s" % signed_xml @@ -255,6 +252,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): # -------------------------------------------------------------------------- + def create_id(): """ Create a string of 40 random characters from the set [a-p], can be used as a unique identifier of objects. @@ -266,6 +264,7 @@ def create_id(): ret += chr(random.randint(0, 15) + ord('a')) return ret + def make_temp(string, suffix="", decode=True): """ xmlsec needs files in some cases where only strings exist, hence the need for this function. It creates a temporary file with the @@ -288,9 +287,34 @@ def make_temp(string, suffix="", decode=True): ntf.seek(0) return ntf, ntf.name + def split_len(seq, length): - return [seq[i:i+length] for i in range(0, len(seq), length)] - + return [seq[i:i + length] for i in range(0, len(seq), length)] + +# -------------------------------------------------------------------------- + +M2_TIME_FORMAT = "%b %d %H:%M:%S %Y" + + +def to_time(_time): + assert _time.endswith(" GMT") + _time = _time[:-4] + return mktime(str_to_time(_time, M2_TIME_FORMAT)) + + +def active_cert(key): + cert_str = pem_format(key) + certificate = load_cert_string(cert_str) + not_before = to_time(str(certificate.get_not_before())) + not_after = to_time(str(certificate.get_not_after())) + try: + assert not_before < utc_now() + assert not_after > utc_now() + return True + except AssertionError: + return False + + def cert_from_key_info(key_info): """ Get all X509 certs from a KeyInfo instance. Care is taken to make sure that the certs are continues sequences of bytes. @@ -307,11 +331,15 @@ def cert_from_key_info(key_info): #print "X509Data",x509_data x509_certificate = x509_data.x509_certificate cert = x509_certificate.text.strip() - cert = "\n".join(split_len("".join([ - s.strip() for s in cert.split()]),64)) - res.append(cert) + cert = "\n".join(split_len("".join([s.strip() for s in + cert.split()]), 64)) + if active_cert(cert): + res.append(cert) + else: + logger.info("Inactive cert") return res + def cert_from_key_info_dict(key_info): """ Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure that the certs are continues sequences of bytes. @@ -330,11 +358,15 @@ def cert_from_key_info_dict(key_info): for x509_data in key_info["x509_data"]: x509_certificate = x509_data["x509_certificate"] cert = x509_certificate["text"].strip() - cert = "\n".join(split_len("".join([ - s.strip() for s in cert.split()]),64)) - res.append(cert) + cert = "\n".join(split_len("".join([s.strip() for s in + cert.split()]), 64)) + if active_cert(cert): + res.append(cert) + else: + logger.info("Inactive cert") return res + def cert_from_instance(instance): """ Find certificates that are part of an instance @@ -350,25 +382,30 @@ def cert_from_instance(instance): from M2Crypto.__m2crypto import bn_to_mpi from M2Crypto.__m2crypto import hex_to_bn + def intarr2long(arr): return long(''.join(["%02x" % byte for byte in arr]), 16) + def dehexlify(bi): s = hexlify(bi) - return [int(s[i]+s[i+1], 16) for i in range(0,len(s),2)] + return [int(s[i] + s[i + 1], 16) for i in range(0, len(s), 2)] + def long_to_mpi(num): """Converts a python integer or long to OpenSSL MPInt used by M2Crypto. Borrowed from Snowball.Shared.Crypto""" - h = hex(num)[2:] # strip leading 0x in string + h = hex(num)[2:] # strip leading 0x in string if len(h) % 2 == 1: - h = '0' + h # add leading 0 to get even number of hexdigits - return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum + h = '0' + h # add leading 0 to get even number of hexdigits + return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum + def base64_to_long(data): _d = base64.urlsafe_b64decode(data + '==') return intarr2long(dehexlify(_d)) + def key_from_key_value(key_info): res = [] for value in key_info.key_value: @@ -376,10 +413,11 @@ def key_from_key_value(key_info): e = base64_to_long(value.rsa_key_value.exponent) m = base64_to_long(value.rsa_key_value.modulus) key = M2Crypto.RSA.new_pub_key((long_to_mpi(e), - long_to_mpi(m))) + long_to_mpi(m))) res.append(key) return res + def key_from_key_value_dict(key_info): res = [] if not "key_value" in key_info: @@ -396,10 +434,28 @@ def key_from_key_value_dict(key_info): # ============================================================================= + +def rsa_load(filename): + """Read a PEM-encoded RSA key pair from a file.""" + return M2Crypto.RSA.load_key(filename, M2Crypto.util.no_passphrase_callback) + + +def rsa_loads(key): + """Read a PEM-encoded RSA key pair from a string.""" + return M2Crypto.RSA.load_key_string(key, + M2Crypto.util.no_passphrase_callback) + + +def x509_rsa_loads(string): + cert = M2Crypto.X509.load_cert_string(string) + return cert.get_pubkey().get_rsa() + + def pem_format(key): return "\n".join(["-----BEGIN CERTIFICATE-----", - key,"-----END CERTIFICATE-----"]) + key, "-----END CERTIFICATE-----"]) + def parse_xmlsec_output(output): """ Parse the output from xmlsec to try to find out if the command was successfull or not. @@ -416,12 +472,13 @@ def parse_xmlsec_output(output): __DEBUG = 0 -LOG_LINE = 60*"="+"\n%s\n"+60*"-"+"\n%s"+60*"=" -LOG_LINE_2 = 60*"="+"\n%s\n%s\n"+60*"-"+"\n%s"+60*"=" +LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" +LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" + def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem", - node_name=NODE_NAME, debug=False, node_id=None, - id_attr=""): + node_name=NODE_NAME, debug=False, node_id=None, + id_attr=""): """ Verifies the signature of a XML document. :param enctext: The signed XML document @@ -481,6 +538,7 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem", # --------------------------------------------------------------------------- + def read_cert_from_file(cert_file, cert_type): """ Reads a certificate from a file. The assumption is that there is only one certificate in the file @@ -516,6 +574,7 @@ def read_cert_from_file(cert_file, cert_type): data = open(cert_file).read() return base64.b64encode(str(data)) + def security_context(conf, debug=None): """ Creates a security context based on the configuration @@ -535,14 +594,15 @@ def security_context(conf, debug=None): _only_md = False return SecurityContext(conf.xmlsec_binary, conf.key_file, - cert_file=conf.cert_file, metadata=metadata, - debug=debug, only_use_keys_in_metadata=_only_md) + cert_file=conf.cert_file, metadata=metadata, + debug=debug, only_use_keys_in_metadata=_only_md) + class SecurityContext(object): - def __init__(self, xmlsec_binary, key_file="", key_type= "pem", - cert_file="", cert_type="pem", metadata=None, - debug=False, template="", encrypt_key_type="des-192", - only_use_keys_in_metadata=False): + def __init__(self, xmlsec_binary, key_file="", key_type="pem", + cert_file="", cert_type="pem", metadata=None, + debug=False, template="", encrypt_key_type="des-192", + only_use_keys_in_metadata=False): self.xmlsec = xmlsec_binary @@ -592,12 +652,9 @@ class SecurityContext(object): _, fil = make_temp("%s" % text, decode=False) ntf = NamedTemporaryFile() - com_list = [self.xmlsec, "--encrypt", - "--pubkey-pem", recv_key, - "--session-key", key_type, - "--xml-data", fil, - "--output", ntf.name, - template] + com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key, + "--session-key", key_type, "--xml-data", fil, + "--output", ntf.name, template] logger.debug("Encryption command: %s" % " ".join(com_list)) @@ -625,11 +682,9 @@ class SecurityContext(object): _, fil = make_temp("%s" % enctext, decode=False) ntf = NamedTemporaryFile() - com_list = [self.xmlsec, "--decrypt", - "--privkey-pem", self.key_file, - "--output", ntf.name, - "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, - fil] + com_list = [self.xmlsec, "--decrypt", "--privkey-pem", + self.key_file, "--output", ntf.name, + "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, fil] logger.debug("Decrypt command: %s" % " ".join(com_list)) @@ -646,9 +701,8 @@ class SecurityContext(object): ntf.seek(0) return ntf.read() - def verify_signature(self, enctext, cert_file=None, cert_type="pem", - node_name=NODE_NAME, node_id=None, id_attr=""): + node_name=NODE_NAME, node_id=None, id_attr=""): """ Verifies the signature of a XML document. :param enctext: The XML document as a string @@ -773,22 +827,22 @@ class SecurityContext(object): must, origdoc) def correctly_signed_authn_query(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "authn_query", must, origdoc) def correctly_signed_logout_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "logout_request", must, origdoc) def correctly_signed_logout_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "logout_response", must, origdoc) def correctly_signed_attribute_query(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "attribute_query", must, origdoc) @@ -799,31 +853,31 @@ class SecurityContext(object): origdoc) def correctly_signed_authz_decision_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "authz_decision_response", must, origdoc) def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "name_id_mapping_request", must, origdoc) def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "name_id_mapping_response", must, origdoc) def correctly_signed_artifact_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "artifact_request", must, origdoc) def correctly_signed_artifact_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "artifact_response", must, origdoc) @@ -835,19 +889,19 @@ class SecurityContext(object): must, origdoc) def correctly_signed_manage_name_id_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "manage_name_id_response", must, origdoc) def correctly_signed_assertion_id_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "assertion_id_request", must, origdoc) def correctly_signed_assertion_id_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "assertion", must, origdoc) @@ -882,18 +936,16 @@ class SecurityContext(object): try: self._check_signature(decoded_xml, assertion, - class_name(assertion), origdoc) + class_name(assertion), origdoc) except Exception, exc: logger.error("correctly_signed_response: %s" % exc) raise return response - #-------------------------------------------------------------------------- # SIGNATURE PART #-------------------------------------------------------------------------- - def sign_statement_using_xmlsec(self, statement, klass_namn, key=None, key_file=None, nodeid=None, id_attr=""): """Sign a SAML statement using xmlsec. @@ -917,7 +969,6 @@ class SecurityContext(object): _, fil = make_temp("%s" % statement, decode=False) - ntf = NamedTemporaryFile() com_list = [self.xmlsec, "--sign", @@ -975,10 +1026,8 @@ class SecurityContext(object): :return: The signed statement """ - return self.sign_statement_using_xmlsec(statement, - class_name(samlp.AttributeQuery()), - key, key_file, nodeid, - id_attr=id_attr) + return self.sign_statement_using_xmlsec(statement, class_name( + samlp.AttributeQuery()), key, key_file, nodeid, id_attr=id_attr) def multiple_signatures(self, statement, to_sign, key=None, key_file=None): """ @@ -991,15 +1040,15 @@ class SecurityContext(object): :param key_file: A file that contains the key to be used :return: A possibly multiple signed statement """ - for (item, id, id_attr) in to_sign: - if not id: + for (item, sid, id_attr) in to_sign: + if not sid: if not item.id: - id = item.id = sid() + sid = item.id = sid() else: - id = item.id + sid = item.id if not item.signature: - item.signature = pre_signature_part(id, self.cert_file) + item.signature = pre_signature_part(sid, self.cert_file) statement = self.sign_statement_using_xmlsec(statement, class_name(item), @@ -1024,32 +1073,30 @@ def pre_signature_part(ident, public_key=None, identifier=None): :return: A preset signature part """ - signature_method = ds.SignatureMethod(algorithm = ds.SIG_RSA_SHA1) + signature_method = ds.SignatureMethod(algorithm=ds.SIG_RSA_SHA1) canonicalization_method = ds.CanonicalizationMethod( - algorithm = ds.ALG_EXC_C14N) - trans0 = ds.Transform(algorithm = ds.TRANSFORM_ENVELOPED) - trans1 = ds.Transform(algorithm = ds.ALG_EXC_C14N) - transforms = ds.Transforms(transform = [trans0, trans1]) - digest_method = ds.DigestMethod(algorithm = ds.DIGEST_SHA1) + algorithm=ds.ALG_EXC_C14N) + trans0 = ds.Transform(algorithm=ds.TRANSFORM_ENVELOPED) + trans1 = ds.Transform(algorithm=ds.ALG_EXC_C14N) + transforms = ds.Transforms(transform=[trans0, trans1]) + digest_method = ds.DigestMethod(algorithm=ds.DIGEST_SHA1) - reference = ds.Reference(uri = "#%s" % ident, - digest_value = ds.DigestValue(), - transforms = transforms, - digest_method = digest_method) + reference = ds.Reference(uri="#%s" % ident, digest_value=ds.DigestValue(), + transforms=transforms, digest_method=digest_method) - signed_info = ds.SignedInfo(signature_method = signature_method, - canonicalization_method = canonicalization_method, - reference = reference) + signed_info = ds.SignedInfo(signature_method=signature_method, + canonicalization_method=canonicalization_method, + reference=reference) - signature = ds.Signature(signed_info=signed_info, - signature_value=ds.SignatureValue()) + signature = ds.Signature(signed_info=signed_info, + signature_value=ds.SignatureValue()) if identifier: signature.id = "Signature%d" % identifier if public_key: - x509_data = ds.X509Data(x509_certificate=[ds.X509Certificate( - text=public_key)]) + x509_data = ds.X509Data( + x509_certificate=[ds.X509Certificate(text=public_key)]) key_info = ds.KeyInfo(x509_data=x509_data) signature.key_info = key_info