diff --git a/.gitignore b/.gitignore index f3351b0..9f6ae01 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,17 @@ example/sp/nocert_sp_conf/sp.xml example/sp/nocert_sp_conf/sp_conf.py example/sp/nocert_sp_conf/who.ini + +example/sp-repoze/my_sp.xml + +example/sp-repoze/pki/localhost.ca.crt + +example/sp-repoze/pki/localhost.ca.key + +example/sp-repoze/sp.xml + +example/sp-repoze/sp.xml + +example/sp-repoze/sp_conf.py + +example/sp-repoze/sp_conf.py diff --git a/example/idp2/idp.py b/example/idp2/idp.py index b1188a9..04843bd 100755 --- a/example/idp2/idp.py +++ b/example/idp2/idp.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import argparse import base64 - +import xmldsig as ds import re import logging import time @@ -24,6 +24,7 @@ from saml2.authn_context import AuthnBroker from saml2.authn_context import PASSWORD from saml2.authn_context import UNSPECIFIED from saml2.authn_context import authn_context_class_ref +from saml2.extension import pefim from saml2.httputil import Response from saml2.httputil import NotFound from saml2.httputil import geturl @@ -38,7 +39,7 @@ from saml2.s_utils import rndstr, exception_trace from saml2.s_utils import UnknownPrincipal from saml2.s_utils import UnsupportedBinding from saml2.s_utils import PolicyError -from saml2.sigver import verify_redirect_signature +from saml2.sigver import verify_redirect_signature, cert_from_instance, encrypt_cert_from_item logger = logging.getLogger("saml2.idp") @@ -125,8 +126,9 @@ class Service(object): return resp(self.environ, self.start_response) else: try: + _encrypt_cert = encrypt_cert_from_item(_dict["req_info"].message) return self.do(_dict["SAMLRequest"], binding, - _dict["RelayState"]) + _dict["RelayState"], encrypt_cert=_encrypt_cert) except KeyError: # Can live with no relay state return self.do(_dict["SAMLRequest"], binding) @@ -151,7 +153,7 @@ class Service(object): resp = Response(http_args["data"], headers=http_args["headers"]) return resp(self.environ, self.start_response) - def do(self, query, binding, relay_state=""): + def do(self, query, binding, relay_state="", encrypt_cert=None): pass def redirect(self): @@ -277,7 +279,7 @@ class SSO(Service): return resp_args, _resp - def do(self, query, binding_in, relay_state=""): + def do(self, query, binding_in, relay_state="", encrypt_cert=None): try: resp_args, _resp = self.verify_request(query, binding_in) except UnknownPrincipal, excp: @@ -297,13 +299,10 @@ class SSO(Service): if REPOZE_ID_EQUIVALENT: identity[REPOZE_ID_EQUIVALENT] = self.user try: - sign_assertion = IDP.config.getattr("sign_assertion", "idp") - if sign_assertion is None: - sign_assertion = False _resp = IDP.create_authn_response( identity, userid=self.user, - authn=AUTHN_BROKER[self.environ["idp.authn_ref"]], sign_assertion=sign_assertion, - sign_response=False, **resp_args) + authn=AUTHN_BROKER[self.environ["idp.authn_ref"]], encrypt_cert=encrypt_cert, + **resp_args) except Exception, excp: logging.error(exception_trace(excp)) resp = ServiceError("Exception: %s" % (excp,)) @@ -537,7 +536,7 @@ def not_found(environ, start_response): # return subject, sp_entity_id class SLO(Service): - def do(self, request, binding, relay_state=""): + def do(self, request, binding, relay_state="", encrypt_cert=None): logger.info("--- Single Log Out Service ---") try: _, body = request.split("\n") @@ -589,7 +588,7 @@ class SLO(Service): class NMI(Service): - def do(self, query, binding, relay_state=""): + def do(self, query, binding, relay_state="", encrypt_cert=None): logger.info("--- Manage Name ID Service ---") req = IDP.parse_manage_name_id_request(query, binding) request = req.message @@ -617,7 +616,7 @@ class NMI(Service): # Only URI binding class AIDR(Service): - def do(self, aid, binding, relay_state=""): + def do(self, aid, binding, relay_state="", encrypt_cert=None): logger.info("--- Assertion ID Service ---") try: @@ -646,7 +645,7 @@ class AIDR(Service): # ---------------------------------------------------------------------------- class ARS(Service): - def do(self, request, binding, relay_state=""): + def do(self, request, binding, relay_state="", encrypt_cert=None): _req = IDP.parse_artifact_resolve(request, binding) msg = IDP.create_artifact_response(_req, _req.artifact.text) @@ -664,7 +663,7 @@ class ARS(Service): # Only SOAP binding class AQS(Service): - def do(self, request, binding, relay_state=""): + def do(self, request, binding, relay_state="", encrypt_cert=None): logger.info("--- Authn Query Service ---") _req = IDP.parse_authn_query(request, binding) _query = _req.message @@ -688,7 +687,7 @@ class AQS(Service): # Only SOAP binding class ATTR(Service): - def do(self, request, binding, relay_state=""): + def do(self, request, binding, relay_state="", encrypt_cert=None): logger.info("--- Attribute Query Service ---") _req = IDP.parse_attribute_query(request, binding) @@ -721,7 +720,7 @@ class ATTR(Service): class NIM(Service): - def do(self, query, binding, relay_state=""): + def do(self, query, binding, relay_state="", encrypt_cert=None): req = IDP.parse_name_id_mapping_request(query, binding) request = req.message # Do the necessary stuff diff --git a/example/sp-repoze/sp.xml b/example/sp-repoze/sp.xml deleted file mode 100644 index 9fbb178..0000000 --- a/example/sp-repoze/sp.xml +++ /dev/null @@ -1,34 +0,0 @@ - -MIIC8jCCAlugAwIBAgIJAJHg2V5J31I8MA0GCSqGSIb3DQEBBQUAMFoxCzAJBgNV -BAYTAlNFMQ0wCwYDVQQHEwRVbWVhMRgwFgYDVQQKEw9VbWVhIFVuaXZlcnNpdHkx -EDAOBgNVBAsTB0lUIFVuaXQxEDAOBgNVBAMTB1Rlc3QgU1AwHhcNMDkxMDI2MTMz -MTE1WhcNMTAxMDI2MTMzMTE1WjBaMQswCQYDVQQGEwJTRTENMAsGA1UEBxMEVW1l -YTEYMBYGA1UEChMPVW1lYSBVbml2ZXJzaXR5MRAwDgYDVQQLEwdJVCBVbml0MRAw -DgYDVQQDEwdUZXN0IFNQMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDkJWP7 -bwOxtH+E15VTaulNzVQ/0cSbM5G7abqeqSNSs0l0veHr6/ROgW96ZeQ57fzVy2MC -FiQRw2fzBs0n7leEmDJyVVtBTavYlhAVXDNa3stgvh43qCfLx+clUlOvtnsoMiiR -mo7qf0BoPKTj7c0uLKpDpEbAHQT4OF1HRYVxMwIDAQABo4G/MIG8MB0GA1UdDgQW -BBQ7RgbMJFDGRBu9o3tDQDuSoBy7JjCBjAYDVR0jBIGEMIGBgBQ7RgbMJFDGRBu9 -o3tDQDuSoBy7JqFepFwwWjELMAkGA1UEBhMCU0UxDTALBgNVBAcTBFVtZWExGDAW -BgNVBAoTD1VtZWEgVW5pdmVyc2l0eTEQMA4GA1UECxMHSVQgVW5pdDEQMA4GA1UE -AxMHVGVzdCBTUIIJAJHg2V5J31I8MAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF -BQADgYEAMuRwwXRnsiyWzmRikpwinnhTmbooKm5TINPE7A7gSQ710RxioQePPhZO -zkM27NnHTrCe2rBVg0EGz7QTd1JIwLPvgoj4VTi/fSha/tXrYUaqc9AqU1kWI4WN -+vffBGQ09mo+6CffuFTZYeOhzP/2stAPwCTU4kxEoiy0KpZMANI= -MIIC8jCCAlugAwIBAgIJAJHg2V5J31I8MA0GCSqGSIb3DQEBBQUAMFoxCzAJBgNV -BAYTAlNFMQ0wCwYDVQQHEwRVbWVhMRgwFgYDVQQKEw9VbWVhIFVuaXZlcnNpdHkx -EDAOBgNVBAsTB0lUIFVuaXQxEDAOBgNVBAMTB1Rlc3QgU1AwHhcNMDkxMDI2MTMz -MTE1WhcNMTAxMDI2MTMzMTE1WjBaMQswCQYDVQQGEwJTRTENMAsGA1UEBxMEVW1l -YTEYMBYGA1UEChMPVW1lYSBVbml2ZXJzaXR5MRAwDgYDVQQLEwdJVCBVbml0MRAw -DgYDVQQDEwdUZXN0IFNQMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDkJWP7 -bwOxtH+E15VTaulNzVQ/0cSbM5G7abqeqSNSs0l0veHr6/ROgW96ZeQ57fzVy2MC -FiQRw2fzBs0n7leEmDJyVVtBTavYlhAVXDNa3stgvh43qCfLx+clUlOvtnsoMiiR -mo7qf0BoPKTj7c0uLKpDpEbAHQT4OF1HRYVxMwIDAQABo4G/MIG8MB0GA1UdDgQW -BBQ7RgbMJFDGRBu9o3tDQDuSoBy7JjCBjAYDVR0jBIGEMIGBgBQ7RgbMJFDGRBu9 -o3tDQDuSoBy7JqFepFwwWjELMAkGA1UEBhMCU0UxDTALBgNVBAcTBFVtZWExGDAW -BgNVBAoTD1VtZWEgVW5pdmVyc2l0eTEQMA4GA1UECxMHSVQgVW5pdDEQMA4GA1UE -AxMHVGVzdCBTUIIJAJHg2V5J31I8MAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF -BQADgYEAMuRwwXRnsiyWzmRikpwinnhTmbooKm5TINPE7A7gSQ710RxioQePPhZO -zkM27NnHTrCe2rBVg0EGz7QTd1JIwLPvgoj4VTi/fSha/tXrYUaqc9AqU1kWI4WN -+vffBGQ09mo+6CffuFTZYeOhzP/2stAPwCTU4kxEoiy0KpZMANI= -Exempel ABExempel ABExample Co.http://www.example.com/rolandJohnSmithjohn.smith@example.com diff --git a/example/sp-repoze/sp_conf.py b/example/sp-repoze/sp_conf.example similarity index 99% rename from example/sp-repoze/sp_conf.py rename to example/sp-repoze/sp_conf.example index d324427..5d244ac 100644 --- a/example/sp-repoze/sp_conf.py +++ b/example/sp-repoze/sp_conf.example @@ -48,4 +48,4 @@ CONFIG = { }, "loglevel": "debug", } -} +} \ No newline at end of file diff --git a/src/s2repoze/plugins/sp.py b/src/s2repoze/plugins/sp.py index 68a26a2..d1cc1cc 100644 --- a/src/s2repoze/plugins/sp.py +++ b/src/s2repoze/plugins/sp.py @@ -27,6 +27,8 @@ import threading import traceback import saml2 from urlparse import parse_qs, urlparse +from saml2.md import Extensions +import xmldsig as ds from StringIO import StringIO @@ -35,6 +37,7 @@ from paste.httpexceptions import HTTPNotImplemented from paste.httpexceptions import HTTPInternalServerError from paste.request import parse_dict_querystring from paste.request import construct_url +from saml2.extension.pefim import SPCertEnc from saml2.httputil import SeeOther from saml2.client_base import ECP_SERVICE from zope.interface import implements @@ -42,7 +45,7 @@ from zope.interface import implements from repoze.who.interfaces import IChallenger, IIdentifier, IAuthenticator from repoze.who.interfaces import IMetadataProvider -from saml2 import ecp, BINDING_HTTP_REDIRECT +from saml2 import ecp, BINDING_HTTP_REDIRECT, element_to_extension_element from saml2 import BINDING_HTTP_POST from saml2.client import Saml2Client @@ -126,7 +129,7 @@ class SAML2Plugin(object): implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider) def __init__(self, rememberer_name, config, saml_client, wayf, cache, - sid_store=None, discovery="", idp_query_param=""): + sid_store=None, discovery="", idp_query_param="", sid_store_cert=None,): self.rememberer_name = rememberer_name self.wayf = wayf self.saml_client = saml_client @@ -143,6 +146,11 @@ class SAML2Plugin(object): self.outstanding_queries = shelve.open(sid_store, writeback=True) else: self.outstanding_queries = {} + if sid_store_cert: + self.outstanding_certs = shelve.open(sid_store_cert, writeback=True) + else: + self.outstanding_certs = {} + self.iam = platform.node() @@ -362,15 +370,30 @@ class SAML2Plugin(object): dest = srvs[0]["location"] logger.debug("destination: %s" % dest) + extensions = None + cert = None + + if _cli.config.generate_cert_func is not None: + cert_str, req_key_str = _cli.config.generate_cert_func() + cert = { + "cert": cert_str, + "key": req_key_str + } + spcertenc = SPCertEnc(x509_data=ds.X509Data(x509_certificate=ds.X509Certificate(text=cert_str))) + extensions = Extensions(extension_elements=[element_to_extension_element(spcertenc)]) + if _cli.authn_requests_signed: _sid = saml2.s_utils.sid(_cli.seed) msg_str = _cli.create_authn_request(dest, vorg=vorg_name, sign=_cli.authn_requests_signed, - message_id=_sid) + message_id=_sid, extensions=extensions) else: - req = _cli.create_authn_request(dest, vorg=vorg_name, sign=False) + req = _cli.create_authn_request(dest, vorg=vorg_name, sign=False, extensions=extensions) msg_str = "%s" % req _sid = req.id + if cert is not None: + self.outstanding_certs[_sid] = cert + ht_args = _cli.apply_binding(_binding, msg_str, destination=dest, relay_state=came_from) logger.debug("ht_args: %s" % ht_args) @@ -417,7 +440,8 @@ class SAML2Plugin(object): # Evaluate the response, returns a AuthnResponse instance try: authresp = self.saml_client.parse_authn_request_response( - post["SAMLResponse"], binding, self.outstanding_queries) + post["SAMLResponse"], binding, self.outstanding_queries, self.outstanding_certs) + except Exception, excp: logger.exception("Exception: %s" % (excp,)) raise diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index 8ea0a0c..efae1f0 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -558,6 +558,8 @@ class SamlBase(ExtensionContainer): except AttributeError: # Backwards compatibility with ET < 1.3 ElementTree._namespace_map[uri] = prefix + except ValueError: + pass return ElementTree.tostring(self._to_element_tree(), encoding="UTF-8") diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 876baa4..b3c53f9 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -122,8 +122,9 @@ class Base(Entity): self.allow_unsolicited = False self.authn_requests_signed = False self.want_assertions_signed = False + self.want_response_signed = False for foo in ["allow_unsolicited", "authn_requests_signed", - "logout_requests_signed", "want_assertions_signed"]: + "logout_requests_signed", "want_assertions_signed", "want_response_signed"]: v = self.config.getattr(foo, "sp") if v is True or v == 'true': setattr(self, foo, True) @@ -234,7 +235,9 @@ class Base(Entity): client_crt = None if "client_crt" in kwargs: client_crt = kwargs["client_crt"] + args = {} + try: args["assertion_consumer_service_url"] = kwargs[ "assertion_consumer_service_urls"][0] @@ -505,7 +508,7 @@ class Base(Entity): # ======== response handling =========== - def parse_authn_request_response(self, xmlstr, binding, outstanding=None): + def parse_authn_request_response(self, xmlstr, binding, outstanding=None, outstanding_certs=None): """ Deal with an AuthnResponse :param xmlstr: The reply as a xml string @@ -525,8 +528,10 @@ class Base(Entity): if xmlstr: kwargs = { "outstanding_queries": outstanding, + "outstanding_certs": outstanding_certs, "allow_unsolicited": self.allow_unsolicited, "want_assertions_signed": self.want_assertions_signed, + "want_response_signed": self.want_response_signed, "return_addrs": self.service_urls(), "entity_id": self.config.entityid, "attribute_converters": self.config.attribute_converters, diff --git a/src/saml2/config.py b/src/saml2/config.py index 094a27f..1a96f16 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -65,7 +65,9 @@ COMMON_ARGS = [ "xmlsec_path", "extension_schemas", "cert_handler_extra_class", + "generate_cert_func", "generate_cert_info", + "verify_encrypt_cert", "tmp_cert_file", "tmp_key_file", "validate_certificate", @@ -78,6 +80,7 @@ SP_ARGS = [ "idp", "aa", "subject_data", + "want_response_signed", "want_assertions_signed", "authn_requests_signed", "name_form", @@ -92,6 +95,8 @@ SP_ARGS = [ AA_IDP_ARGS = [ "sign_assertion", + "sign_response", + "encrypt_assertion", "want_authn_requests_signed", "want_authn_requests_only_with_valid_cert", "provided_attributes", @@ -210,6 +215,8 @@ class Config(object): self.allow_unknown_attributes = False self.extension_schema = {} self.cert_handler_extra_class = None + self.verify_encrypt_cert = None + self.generate_cert_func = None self.generate_cert_info = None self.tmp_cert_file = None self.tmp_key_file = None diff --git a/src/saml2/entity.py b/src/saml2/entity.py index de34222..58243f5 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -1,5 +1,6 @@ import base64 from binascii import hexlify +import copy import logging from hashlib import sha1 from saml2.metadata import ENDPOINTS @@ -19,10 +20,10 @@ from saml2 import soap from saml2 import element_to_extension_element from saml2 import extension_elements_to_elements -from saml2.saml import NameID +from saml2.saml import NameID, EncryptedAssertion from saml2.saml import Issuer from saml2.saml import NAMEID_FORMAT_ENTITY -from saml2.response import LogoutResponse +from saml2.response import LogoutResponse, AuthnResponse from saml2.time_util import instant from saml2.s_utils import sid from saml2.s_utils import UnravelError @@ -31,7 +32,7 @@ from saml2.s_utils import rndstr from saml2.s_utils import success_status_factory from saml2.s_utils import decode_base64_and_inflate from saml2.s_utils import UnsupportedBinding -from saml2.samlp import AuthnRequest, AuthzDecisionQuery, AuthnQuery +from saml2.samlp import AuthnRequest, AuthzDecisionQuery, AuthnQuery, response_from_string from saml2.samlp import AssertionIDRequest from saml2.samlp import ManageNameIDRequest from saml2.samlp import NameIDMappingRequest @@ -49,7 +50,8 @@ from saml2 import VERSION from saml2 import class_name from saml2.config import config_factory from saml2.httpbase import HTTPBase -from saml2.sigver import security_context, response_factory, SigverError +from saml2.sigver import security_context, response_factory, SigverError, CryptoBackendXmlSec1, make_temp, \ + pre_encryption_part from saml2.sigver import pre_signature_part from saml2.sigver import signed_instance_factory from saml2.virtual_org import VirtualOrg @@ -427,7 +429,7 @@ class Entity(HTTPBase): msg.extension_elements = extensions def _response(self, in_response_to, consumer_url=None, status=None, - issuer=None, sign=False, to_sign=None, **kwargs): + issuer=None, sign=False, to_sign=None, encrypt_assertion=False, encrypt_cert=None, **kwargs): """ Create a Response. :param in_response_to: The session identifier of the request @@ -454,10 +456,23 @@ class Entity(HTTPBase): self._add_info(response, **kwargs) + if not sign and to_sign and not encrypt_assertion: + return signed_instance_factory(response, self.sec, to_sign) + + if encrypt_assertion: + sign_class = [(class_name(response), response.id)] + if sign: + response.signature = pre_signature_part(response.id, self.sec.my_cert, 1) + cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) + _, cert_file = make_temp("%s" % encrypt_cert, decode=False) + response = cbxs.encrypt_assertion(response, cert_file, pre_encryption_part())#template(response.assertion.id)) + if sign: + return signed_instance_factory(response, self.sec, sign_class) + else: + return response + if sign: return self.sign(response, to_sign=to_sign) - elif to_sign: - return signed_instance_factory(response, self.sec, to_sign) else: return response @@ -762,7 +777,7 @@ class Entity(HTTPBase): # ------------------------------------------------------------------------ - def _parse_response(self, xmlstr, response_cls, service, binding, **kwargs): + def _parse_response(self, xmlstr, response_cls, service, binding, outstanding_certs=None, **kwargs): """ Deal with a Response :param xmlstr: The response as a xml string @@ -802,11 +817,18 @@ class Entity(HTTPBase): raise xmlstr = self.unravel(xmlstr, binding, response_cls.msgtype) + origxml = xmlstr + if outstanding_certs is not None: + _response = samlp.any_response_from_string(xmlstr) + if len(_response.encrypted_assertion) > 0: + _, cert_file = make_temp("%s" % outstanding_certs[_response.in_response_to]["key"], decode=False) + cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) + xmlstr = cbxs.decrypt(xmlstr, cert_file) if not xmlstr: # Not a valid reponse return None try: - response = response.loads(xmlstr, False) + response = response.loads(xmlstr, False, origxml=origxml) except SigverError, err: logger.error("Signature Error: %s" % err) return None @@ -817,6 +839,14 @@ class Entity(HTTPBase): logger.debug("XMLSTR: %s" % xmlstr) + if hasattr(response.response, 'encrypted_assertion'): + for encrypted_assertion in response.response.encrypted_assertion: + if encrypted_assertion.extension_elements is not None: + assertion_list = extension_elements_to_elements(encrypted_assertion.extension_elements, [saml]) + for assertion in assertion_list: + _assertion = saml.assertion_from_string(str(assertion)) + response.response.assertion.append(_assertion) + if response: response = response.verify() diff --git a/src/saml2/response.py b/src/saml2/response.py index 55cb5ea..b987130 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -268,6 +268,7 @@ class StatusResponse(object): self.in_response_to = None self.signature_check = self.sec.correctly_signed_response self.require_signature = False + self.require_response_signature = False self.not_signed = False self.asynchop = asynchop @@ -318,7 +319,9 @@ class StatusResponse(object): logger.debug("xmlstr: %s" % (self.xmlstr,)) try: - self.response = self.signature_check(xmldata, origdoc=origxml, must=self.require_signature) + self.response = self.signature_check(xmldata, origdoc=origxml, must=self.require_signature, + require_response_signature=self.require_response_signature) + except TypeError: raise except SignatureError: @@ -452,7 +455,7 @@ class AuthnResponse(StatusResponse): return_addrs=None, outstanding_queries=None, timeslack=0, asynchop=True, allow_unsolicited=False, test=False, allow_unknown_attributes=False, - want_assertions_signed=False, **kwargs): + want_assertions_signed=False, want_response_signed=False, **kwargs): StatusResponse.__init__(self, sec_context, return_addrs, timeslack, asynchop=asynchop) @@ -469,6 +472,7 @@ class AuthnResponse(StatusResponse): self.session_not_on_or_after = 0 self.allow_unsolicited = allow_unsolicited self.require_signature = want_assertions_signed + self.require_response_signature = want_response_signed self.test = test self.allow_unknown_attributes = allow_unknown_attributes # diff --git a/src/saml2/server.py b/src/saml2/server.py index c314f9c..89d0fdd 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -25,6 +25,7 @@ import shelve import threading from saml2.eptid import EptidShelve, Eptid +from saml2.saml import EncryptedAssertion from saml2.sdb import SessionStorage from saml2.schema import soapenv @@ -44,7 +45,7 @@ from saml2.request import AuthnQuery from saml2.s_utils import MissingValue, Unknown, rndstr -from saml2.sigver import pre_signature_part, signed_instance_factory +from saml2.sigver import pre_signature_part, signed_instance_factory, CertificateError, CryptoBackendXmlSec1 from saml2.assertion import Assertion from saml2.assertion import Policy @@ -282,7 +283,7 @@ class Server(Entity): sp_entity_id, identity=None, name_id=None, status=None, authn=None, issuer=None, policy=None, sign_assertion=False, sign_response=False, - best_effort=False): + best_effort=False, encrypt_assertion=False, encrypt_cert=None): """ Create a response. A layer of indirection. :param in_response_to: The session identifier of the request @@ -352,7 +353,8 @@ class Server(Entity): self.session_db.store_assertion(assertion, to_sign) return self._response(in_response_to, consumer_url, status, issuer, - sign_response, to_sign, **args) + sign_response, to_sign, encrypt_assertion=encrypt_assertion, + encrypt_cert=encrypt_cert, **args) # ------------------------------------------------------------------------ @@ -425,7 +427,7 @@ class Server(Entity): def create_authn_response(self, identity, in_response_to, destination, sp_entity_id, name_id_policy=None, userid=None, name_id=None, authn=None, issuer=None, - sign_response=False, sign_assertion=False, + sign_response=None, sign_assertion=None, encrypt_cert=None, encrypt_assertion=None, **kwargs): """ Constructs an AuthenticationResponse @@ -454,6 +456,32 @@ class Server(Entity): except KeyError: best_effort = False + if sign_assertion is None: + sign_assertion = self.config.getattr("sign_assertion", "idp") + if sign_assertion is None: + sign_assertion = False + + if sign_response is None: + sign_response = self.config.getattr("sign_response", "idp") + if sign_response is None: + sign_response = False + + if encrypt_assertion is None: + encrypt_assertion = self.config.getattr("encrypt_assertion", "idp") + if encrypt_assertion is None: + encrypt_assertion = False + + if encrypt_assertion: + if encrypt_cert is not None: + verify_encrypt_cert = self.config.getattr("verify_encrypt_cert", "idp") + if verify_encrypt_cert is not None: + if not verify_encrypt_cert(encrypt_cert): + raise CertificateError("Invalid certificate for encryption!") + else: + raise CertificateError("No certificate for encryption!") + else: + encrypt_assertion = False + if not name_id: try: nid_formats = [] @@ -493,8 +521,7 @@ class Server(Entity): try: _authn = authn - - + response = None if (sign_assertion or sign_response) and self.sec.cert_handler.generate_cert(): with self.lock: self.sec.cert_handler.update_cert(True) @@ -508,8 +535,8 @@ class Server(Entity): policy=policy, sign_assertion=sign_assertion, sign_response=sign_response, - best_effort=best_effort) - + best_effort=best_effort, + encrypt_assertion=encrypt_assertion, encrypt_cert=encrypt_cert) return self._authn_response(in_response_to, # in_response_to destination, # consumer_url sp_entity_id, # sp_entity_id @@ -520,7 +547,8 @@ class Server(Entity): policy=policy, sign_assertion=sign_assertion, sign_response=sign_response, - best_effort=best_effort) + best_effort=best_effort, + encrypt_assertion=encrypt_assertion, encrypt_cert=encrypt_cert) except MissingValue, exc: return self.create_error_response(in_response_to, destination, diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 8b30064..0fa5ce4 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -33,12 +33,14 @@ from Crypto.Signature import PKCS1_v1_5 from Crypto.Util.asn1 import DerSequence from Crypto.PublicKey import RSA from saml2.cert import OpenSSLWrapper +from saml2.extension import pefim from saml2.saml import EncryptedAssertion -from saml2.samlp import Response +from saml2.samlp import Response, response_from_string import xmldsig as ds +import xmlenc as enc -from saml2 import samlp, SAMLError +from saml2 import samlp, SAMLError, extension_elements_to_elements from saml2 import class_name from saml2 import saml from saml2 import ExtensionElement @@ -765,9 +767,7 @@ class CryptoBackendXmlSec1(CryptoBackend): "--session-key", key_type, "--xml-data", fil, "--node-xpath", ASSERT_XPATH] - (_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmpl], - exception=EncryptError, - validate_output=False) + (_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmpl], exception=EncryptError, validate_output=False) os.unlink(fil) if not output: @@ -1011,6 +1011,25 @@ def security_context(conf, debug=None): tmp_key_file=conf.tmp_key_file, validate_certificate=conf.validate_certificate) +def encrypt_cert_from_item(item): + _encrypt_cert = None + try: + _elem = extension_elements_to_elements(item.extension_elements[0].children, + [pefim, ds]) + if len(_elem) == 1: + _encrypt_cert = _elem[0].x509_data[0].x509_certificate.text + else: + certs = cert_from_instance(item) + if len(certs) > 0: + _encrypt_cert = certs[0] + if _encrypt_cert is not None: + if _encrypt_cert.find("-----BEGIN CERTIFICATE-----\n") == -1: + _encrypt_cert = "-----BEGIN CERTIFICATE-----\n" + _encrypt_cert + if _encrypt_cert.find("-----END CERTIFICATE-----\n") == -1: + _encrypt_cert = _encrypt_cert + "-----END CERTIFICATE-----\n" + except Exception: + return None + return _encrypt_cert class CertHandlerExtra(object): def __init__(self): @@ -1057,7 +1076,7 @@ class CertHandler(object): self._verify_cert = verify_cert is True self._security_context = security_context self._osw = OpenSSLWrapper() - if key_file is not None: + if key_file is not None and os.path.isfile(key_file): self._key_str = self._osw.read_str_from_file(key_file, key_type) else: self._key_str = "" @@ -1363,99 +1382,114 @@ class SecurityContext(object): only_valid_cert=only_valid_cert) def correctly_signed_authn_request(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False): + origdoc=None, only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "authn_request", must, origdoc, only_valid_cert=only_valid_cert) def correctly_signed_authn_query(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False): + origdoc=None, only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "authn_query", must, origdoc, only_valid_cert) def correctly_signed_logout_request(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False): + origdoc=None, only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "logout_request", must, origdoc, only_valid_cert) def correctly_signed_logout_response(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False): + origdoc=None, only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "logout_response", must, origdoc, only_valid_cert) def correctly_signed_attribute_query(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False): + origdoc=None, only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "attribute_query", must, origdoc, only_valid_cert) def correctly_signed_authz_decision_query(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "authz_decision_query", must, origdoc, only_valid_cert) def correctly_signed_authz_decision_response(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "authz_decision_response", must, origdoc, only_valid_cert) def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "name_id_mapping_request", must, origdoc, only_valid_cert) def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "name_id_mapping_response", must, origdoc, only_valid_cert) def correctly_signed_artifact_request(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False): + origdoc=None, only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "artifact_request", must, origdoc, only_valid_cert) def correctly_signed_artifact_response(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False): + origdoc=None, only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "artifact_response", must, origdoc, only_valid_cert) def correctly_signed_manage_name_id_request(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "manage_name_id_request", must, origdoc, only_valid_cert) def correctly_signed_manage_name_id_response(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "manage_name_id_response", must, origdoc, only_valid_cert) def correctly_signed_assertion_id_request(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, + **kwargs): return self.correctly_signed_message(decoded_xml, "assertion_id_request", must, origdoc, only_valid_cert) def correctly_signed_assertion_id_response(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False): + only_valid_cert=False, **kwargs): return self.correctly_signed_message(decoded_xml, "assertion", must, origdoc, only_valid_cert) - def correctly_signed_response(self, decoded_xml, must=False, origdoc=None): + def correctly_signed_response(self, decoded_xml, must=False, origdoc=None,only_valid_cert=False, + require_response_signature=False, **kwargs): """ Check if a instance is correctly signed, if we have metadata for the IdP that sent the info use that, if not use the key that are in the message if any. @@ -1473,21 +1507,18 @@ class SecurityContext(object): if response.signature: self._check_signature(decoded_xml, response, class_name(response), origdoc) + elif require_response_signature: + raise SignatureError("Signature missing for response") if isinstance(response, Response) and (response.assertion or response.encrypted_assertion): # Try to find the signing cert in the assertion for assertion in ( response.assertion or response.encrypted_assertion): - if response.encrypted_assertion: - decoded_xml = self.decrypt( - assertion.encrypted_data.to_string()) - assertion = saml.assertion_from_string(decoded_xml) - - if not assertion.signature: + if not hasattr(assertion, 'signature') or not assertion.signature: logger.debug("unsigned") if must: - raise SignatureError("Signature missing") + raise SignatureError("Signature missing for assertion") continue else: logger.debug("signed") diff --git a/tests/_test_80_p11_backend.py b/tests/_test_80_p11_backend.py index 4b36a9b..8dcca33 100644 --- a/tests/_test_80_p11_backend.py +++ b/tests/_test_80_p11_backend.py @@ -24,7 +24,7 @@ from saml2 import time_util from saml2 import saml from saml2.s_utils import factory, do_attribute_statement -xmlsec = pytest.importorskip("xmlsec") +#xmlsec = pytest.importorskip("xmlsec") def _find_alts(alts): for a in alts: @@ -57,6 +57,12 @@ class FakeConfig(): self.cert_file = pub_key self.key_file = "pkcs11://%s:0/test?pin=secret1" % P11_MODULE self.debug = False + self.cert_handler_extra_class = None + self.generate_cert_info = False + self.generate_cert_info = False + self.tmp_cert_file = None + self.tmp_key_file = None + self.validate_certificate = False class TestPKCS11(): @@ -173,6 +179,7 @@ class TestPKCS11(): #print "env SOFTHSM_CONF=%s " % softhsm_conf +" ".join(args) logging.debug("Environment {!r}".format(env)) logging.debug("Executing {!r}".format(args)) + args = ['ls'] proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) out, err = proc.communicate() if err is not None and len(err) > 0: diff --git a/tests/pki/cert.crt b/tests/pki/cert.crt new file mode 100644 index 0000000..0db9f90 --- /dev/null +++ b/tests/pki/cert.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMjCCAZsCAQEwDQYJKoZIhvcNAQELBQAwYjELMAkGA1UEBhMCcXcxDzANBgNV +BAgTBnF3ZXJ0eTEPMA0GA1UEBxMGcXdlcnR5MQ8wDQYDVQQKEwZxd2VydHkxDzAN +BgNVBAsTBnF3ZXJ0eTEPMA0GA1UEAxMGcXdlcnR5MB4XDTE0MDIwNDA4NTY0N1oX +DTE0MDIwNTA4NTY0N1owYTELMAkGA1UEBhMCYXMxDzANBgNVBAgTBmFzZGZnaDEP +MA0GA1UEBxMGYXNkZmdoMQ8wDQYDVQQKEwZhc2RmZ2gxDjAMBgNVBAsTBWFzZGZn +MQ8wDQYDVQQDEwZhc2RmZ2gwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKUZ +NSY8xkonIyXOMPwK2aRpkAxnM9/0l15aJ0DCrufCzMWZF+oqne1RZXNZbH4VOqwh +CSq9iXk5ONe0RqIlxvvVvFO+jx6/4vulhEnhOXyrYTJ8AJhBoFRRkJ8VMgN1Hf0C +wIhVqVQpe6UIOkg17PCXoYjicLTeV+CQrrKEajghAgMBAAEwDQYJKoZIhvcNAQEL +BQADgYEAJTspJ8fDcTXlAM0Rgr73EVyhDpIN1MC5hUFay0YrLenOuXaNH9rFzg8j +AdsB5+N6KJg7JB4+oqbucgz9/poqrKUG9amg/uv87vjMR7O7xtlKXt1iSLOdu/uU +cYhtRVSlwRaVfhd6fiYylJag8ujraUmPbqmvM8y23QL5l+O3Nng= +-----END CERTIFICATE----- diff --git a/tests/test_40_sigver.py b/tests/test_40_sigver.py index 3862b62..89a8540 100644 --- a/tests/test_40_sigver.py +++ b/tests/test_40_sigver.py @@ -96,6 +96,7 @@ class FakeConfig(): key_file = PRIV_KEY debug = False cert_handler_extra_class = None + generate_cert_func = None generate_cert_info = False tmp_cert_file = None tmp_key_file = None diff --git a/tests/test_81_certificates.py b/tests/test_81_certificates.py index e07879c..b8d4898 100644 --- a/tests/test_81_certificates.py +++ b/tests/test_81_certificates.py @@ -1,4 +1,5 @@ from os import remove +import os import time __author__ = 'haho0032' @@ -32,7 +33,7 @@ class TestGenerateCertificates(unittest.TestCase): ca_cert, ca_key = osw.create_certificate(cert_info_ca, request=False, write_to_file=True, - cert_dir="pki") + cert_dir=os.path.dirname(os.path.abspath(__file__)) + "/pki") req_cert_str, req_key_str = osw.create_certificate(cert_info, request=True)