Cleaned up a bit
This commit is contained in:
		@@ -52,8 +52,16 @@ from saml2.time_util import str_to_time
 | 
			
		||||
 | 
			
		||||
from tempfile import NamedTemporaryFile
 | 
			
		||||
from subprocess import Popen, PIPE
 | 
			
		||||
from xmlenc import EncryptionMethod, EncryptedKey, CipherData, CipherValue, \
 | 
			
		||||
    EncryptedData
 | 
			
		||||
from xmlenc import EncryptionMethod
 | 
			
		||||
from xmlenc import EncryptedKey
 | 
			
		||||
from xmlenc import CipherData
 | 
			
		||||
from xmlenc import CipherValue
 | 
			
		||||
from xmlenc import EncryptedData
 | 
			
		||||
 | 
			
		||||
from Crypto.Hash import SHA256
 | 
			
		||||
from Crypto.Hash import SHA384
 | 
			
		||||
from Crypto.Hash import SHA512
 | 
			
		||||
from Crypto.Hash import SHA
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@@ -63,7 +71,6 @@ RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
 | 
			
		||||
RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5"
 | 
			
		||||
TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc"
 | 
			
		||||
 | 
			
		||||
from Crypto.Hash import SHA256, SHA384, SHA512, SHA
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SigverError(SAMLError):
 | 
			
		||||
@@ -925,12 +932,14 @@ def security_context(conf, debug=None):
 | 
			
		||||
        raise SigverError('Unknown crypto_backend %s' % (
 | 
			
		||||
            repr(conf.crypto_backend)))
 | 
			
		||||
 | 
			
		||||
    return SecurityContext(crypto, conf.key_file,
 | 
			
		||||
                           cert_file=conf.cert_file, metadata=metadata,
 | 
			
		||||
    return SecurityContext(
 | 
			
		||||
        crypto, conf.key_file, cert_file=conf.cert_file, metadata=metadata,
 | 
			
		||||
        debug=debug, only_use_keys_in_metadata=_only_md,
 | 
			
		||||
        cert_handler_extra_class=conf.cert_handler_extra_class,
 | 
			
		||||
                           generate_cert_info=conf.generate_cert_info, tmp_cert_file=conf.tmp_cert_file,
 | 
			
		||||
                           tmp_key_file=conf.tmp_key_file, validate_certificate=conf.validate_certificate)
 | 
			
		||||
        generate_cert_info=conf.generate_cert_info,
 | 
			
		||||
        tmp_cert_file=conf.tmp_cert_file,
 | 
			
		||||
        tmp_key_file=conf.tmp_key_file,
 | 
			
		||||
        validate_certificate=conf.validate_certificate)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CertHandlerExtra(object):
 | 
			
		||||
@@ -940,7 +949,8 @@ class CertHandlerExtra(object):
 | 
			
		||||
    def use_generate_cert_func(self):
 | 
			
		||||
        raise Exception("use_generate_cert_func function must be implemented")
 | 
			
		||||
 | 
			
		||||
    def generate_cert(self, generate_cert_info, root_cert_string, root_key_string):
 | 
			
		||||
    def generate_cert(self, generate_cert_info, root_cert_string,
 | 
			
		||||
                      root_key_string):
 | 
			
		||||
        raise Exception("generate_cert function must be implemented")
 | 
			
		||||
        #Excepts to return (cert_string, key_string)
 | 
			
		||||
 | 
			
		||||
@@ -953,12 +963,14 @@ class CertHandlerExtra(object):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CertHandler(object):
 | 
			
		||||
    def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=None, key_type="pem",
 | 
			
		||||
                 generate_cert_info=None, cert_handler_extra_class=None, tmp_cert_file=None, tmp_key_file=None,
 | 
			
		||||
                 verify_cert=False):
 | 
			
		||||
    def __init__(self, security_context, cert_file=None, cert_type="pem",
 | 
			
		||||
                 key_file=None, key_type="pem", generate_cert_info=None,
 | 
			
		||||
                 cert_handler_extra_class=None, tmp_cert_file=None,
 | 
			
		||||
                 tmp_key_file=None, verify_cert=False):
 | 
			
		||||
        """
 | 
			
		||||
        Initiates the class for handling certificates. Enables the certificates to either be a single certificate
 | 
			
		||||
        as base functionality or makes it possible to generate a new certificate for each call to the function.
 | 
			
		||||
        Initiates the class for handling certificates. Enables the certificates
 | 
			
		||||
        to either be a single certificate as base functionality or makes it
 | 
			
		||||
        possible to generate a new certificate for each call to the function.
 | 
			
		||||
        :param key_file:
 | 
			
		||||
        :param key_type:
 | 
			
		||||
        :param cert_file:
 | 
			
		||||
@@ -968,7 +980,9 @@ class CertHandler(object):
 | 
			
		||||
        """
 | 
			
		||||
        self._verify_cert = False
 | 
			
		||||
        self._generate_cert = False
 | 
			
		||||
        self._last_cert_verified = None #This cert do not have to be valid, it is just the last cert to be validated.
 | 
			
		||||
        #This cert do not have to be valid, it is just the last cert to be
 | 
			
		||||
        # validated.
 | 
			
		||||
        self._last_cert_verified = None
 | 
			
		||||
        if cert_type == "pem" and key_type == "pem":
 | 
			
		||||
            self._verify_cert = verify_cert is True
 | 
			
		||||
            self._security_context = security_context
 | 
			
		||||
@@ -978,7 +992,8 @@ class CertHandler(object):
 | 
			
		||||
            else:
 | 
			
		||||
                self._key_str = ""
 | 
			
		||||
            if cert_file is not None:
 | 
			
		||||
                self._cert_str = self._osw.read_str_from_file(cert_file, cert_type)
 | 
			
		||||
                self._cert_str = self._osw.read_str_from_file(cert_file,
 | 
			
		||||
                                                              cert_type)
 | 
			
		||||
            else:
 | 
			
		||||
                self._cert_str = ""
 | 
			
		||||
 | 
			
		||||
@@ -989,8 +1004,9 @@ class CertHandler(object):
 | 
			
		||||
 | 
			
		||||
            self._cert_info = None
 | 
			
		||||
            self._generate_cert_func_active = False
 | 
			
		||||
            if generate_cert_info is not None and len(self._cert_str) > 0 and len(self._key_str) > 0 \
 | 
			
		||||
               and tmp_key_file is not None and tmp_cert_file is not None:
 | 
			
		||||
            if generate_cert_info is not None and len(self._cert_str) > 0 and \
 | 
			
		||||
                    len(self._key_str) > 0 and tmp_key_file is not \
 | 
			
		||||
                    None and tmp_cert_file is not None:
 | 
			
		||||
                self._generate_cert = True
 | 
			
		||||
                self._cert_info = generate_cert_info
 | 
			
		||||
                self._cert_handler_extra_class = cert_handler_extra_class
 | 
			
		||||
@@ -999,8 +1015,10 @@ class CertHandler(object):
 | 
			
		||||
        if self._verify_cert:
 | 
			
		||||
            cert_str = self._osw.read_str_from_file(cert_file, "pem")
 | 
			
		||||
            self._last_validated_cert = cert_str
 | 
			
		||||
            if self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_validate_cert_func():
 | 
			
		||||
                self._cert_handler_extra_class.validate_cert(cert_str, self._cert_str, self._key_str)
 | 
			
		||||
            if self._cert_handler_extra_class is not None and \
 | 
			
		||||
                    self._cert_handler_extra_class.use_validate_cert_func():
 | 
			
		||||
                self._cert_handler_extra_class.validate_cert(
 | 
			
		||||
                    cert_str, self._cert_str, self._key_str)
 | 
			
		||||
            else:
 | 
			
		||||
                valid, mess = self._osw.verify(self._cert_str, cert_str)
 | 
			
		||||
                logger.info("CertHandler.verify_cert: %s" % mess)
 | 
			
		||||
@@ -1016,21 +1034,24 @@ class CertHandler(object):
 | 
			
		||||
                self._tmp_cert_str = client_crt
 | 
			
		||||
                #No private key for signing
 | 
			
		||||
                self._tmp_key_str = ""
 | 
			
		||||
            elif self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_generate_cert_func():
 | 
			
		||||
            elif self._cert_handler_extra_class is not None and \
 | 
			
		||||
                    self._cert_handler_extra_class.use_generate_cert_func():
 | 
			
		||||
                (self._tmp_cert_str, self._tmp_key_str) = \
 | 
			
		||||
                    self._cert_handler_extra_class.generate_cert(self._cert_info, self._cert_str, self._key_str)
 | 
			
		||||
            else:
 | 
			
		||||
                self._tmp_cert_str, self._tmp_key_str = self._osw.create_certificate(self._cert_info, request=True)
 | 
			
		||||
                self._tmp_cert_str = self._osw.create_cert_signed_certificate(self._cert_str, self._key_str,
 | 
			
		||||
                self._tmp_cert_str = self._osw.create_cert_signed_certificate(
 | 
			
		||||
                    self._cert_str, self._key_str, self._tmp_cert_str)
 | 
			
		||||
                valid, mess = self._osw.verify(self._cert_str,
 | 
			
		||||
                                               self._tmp_cert_str)
 | 
			
		||||
                valid, mess = self._osw.verify(self._cert_str, self._tmp_cert_str)
 | 
			
		||||
            self._osw.write_str_to_file(self._tmp_cert_file, self._tmp_cert_str)
 | 
			
		||||
            self._osw.write_str_to_file(self._tmp_key_file, self._tmp_key_str)
 | 
			
		||||
            self._security_context.key_file = self._tmp_key_file
 | 
			
		||||
            self._security_context.cert_file = self._tmp_cert_file
 | 
			
		||||
            self._security_context.key_type = "pem"
 | 
			
		||||
            self._security_context.cert_type = "pem"
 | 
			
		||||
            self._security_context.my_cert = read_cert_from_file(self._security_context.cert_file,
 | 
			
		||||
            self._security_context.my_cert = read_cert_from_file(
 | 
			
		||||
                self._security_context.cert_file,
 | 
			
		||||
                self._security_context.cert_type)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -1043,8 +1064,9 @@ class SecurityContext(object):
 | 
			
		||||
    def __init__(self, crypto, key_file="", key_type="pem",
 | 
			
		||||
                 cert_file="", cert_type="pem", metadata=None,
 | 
			
		||||
                 debug=False, template="", encrypt_key_type="des-192",
 | 
			
		||||
                 only_use_keys_in_metadata=False, cert_handler_extra_class=None, generate_cert_info=None,
 | 
			
		||||
                 tmp_cert_file=None, tmp_key_file=None, validate_certificate=None):
 | 
			
		||||
                 only_use_keys_in_metadata=False, cert_handler_extra_class=None,
 | 
			
		||||
                 generate_cert_info=None, tmp_cert_file=None,
 | 
			
		||||
                 tmp_key_file=None, validate_certificate=None):
 | 
			
		||||
 | 
			
		||||
        self.crypto = crypto
 | 
			
		||||
        assert (isinstance(self.crypto, CryptoBackend))
 | 
			
		||||
@@ -1059,8 +1081,10 @@ class SecurityContext(object):
 | 
			
		||||
 | 
			
		||||
        self.my_cert = read_cert_from_file(cert_file, cert_type)
 | 
			
		||||
 | 
			
		||||
        self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, key_type, generate_cert_info,
 | 
			
		||||
                                        cert_handler_extra_class, tmp_cert_file, tmp_key_file, validate_certificate)
 | 
			
		||||
        self.cert_handler = CertHandler(self, cert_file, cert_type, key_file,
 | 
			
		||||
                                        key_type, generate_cert_info,
 | 
			
		||||
                                        cert_handler_extra_class, tmp_cert_file,
 | 
			
		||||
                                        tmp_key_file, validate_certificate)
 | 
			
		||||
 | 
			
		||||
        self.cert_handler.update_cert(True)
 | 
			
		||||
 | 
			
		||||
@@ -1135,7 +1159,8 @@ class SecurityContext(object):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _check_signature(self, decoded_xml, item, node_name=NODE_NAME,
 | 
			
		||||
                         origdoc=None, id_attr="", must=False, only_valid_cert=False):
 | 
			
		||||
                         origdoc=None, id_attr="", must=False,
 | 
			
		||||
                         only_valid_cert=False):
 | 
			
		||||
        #print item
 | 
			
		||||
        try:
 | 
			
		||||
            issuer = item.issuer.text.strip()
 | 
			
		||||
@@ -1179,13 +1204,15 @@ class SecurityContext(object):
 | 
			
		||||
                    try:
 | 
			
		||||
                        if self.verify_signature(origdoc, pem_file,
 | 
			
		||||
                                                 node_name=node_name,
 | 
			
		||||
                                                 node_id=item.id, id_attr=id_attr):
 | 
			
		||||
                                                 node_id=item.id,
 | 
			
		||||
                                                 id_attr=id_attr):
 | 
			
		||||
                            verified = True
 | 
			
		||||
                            break
 | 
			
		||||
                    except Exception:
 | 
			
		||||
                        if self.verify_signature(decoded_xml, pem_file,
 | 
			
		||||
                                                 node_name=node_name,
 | 
			
		||||
                                                 node_id=item.id, id_attr=id_attr):
 | 
			
		||||
                                                 node_id=item.id,
 | 
			
		||||
                                                 id_attr=id_attr):
 | 
			
		||||
                            verified = True
 | 
			
		||||
                            break
 | 
			
		||||
                else:
 | 
			
		||||
@@ -1247,91 +1274,101 @@ class SecurityContext(object):
 | 
			
		||||
                return msg
 | 
			
		||||
 | 
			
		||||
        return self._check_signature(decoded_xml, msg, class_name(msg),
 | 
			
		||||
                                     origdoc, must=must, only_valid_cert=only_valid_cert)
 | 
			
		||||
                                     origdoc, must=must,
 | 
			
		||||
                                     only_valid_cert=only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_authn_request(self, decoded_xml, must=False,
 | 
			
		||||
                                       origdoc=None, only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml, "authn_request",
 | 
			
		||||
                                             must, origdoc, only_valid_cert=only_valid_cert)
 | 
			
		||||
                                             must, origdoc,
 | 
			
		||||
                                             only_valid_cert=only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_authn_query(self, decoded_xml, must=False,
 | 
			
		||||
                                     origdoc=None):
 | 
			
		||||
                                     origdoc=None, only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml, "authn_query",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_logout_request(self, decoded_xml, must=False,
 | 
			
		||||
                                        origdoc=None):
 | 
			
		||||
                                        origdoc=None, only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml, "logout_request",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_logout_response(self, decoded_xml, must=False,
 | 
			
		||||
                                         origdoc=None):
 | 
			
		||||
                                         origdoc=None, only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml, "logout_response",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_attribute_query(self, decoded_xml, must=False,
 | 
			
		||||
                                         origdoc=None):
 | 
			
		||||
                                         origdoc=None, only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml, "attribute_query",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_authz_decision_query(self, decoded_xml, must=False,
 | 
			
		||||
                                              origdoc=None):
 | 
			
		||||
                                              origdoc=None,
 | 
			
		||||
                                              only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "authz_decision_query", must,
 | 
			
		||||
                                             origdoc)
 | 
			
		||||
                                             origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_authz_decision_response(self, decoded_xml, must=False,
 | 
			
		||||
                                                 origdoc=None):
 | 
			
		||||
                                                 origdoc=None,
 | 
			
		||||
                                                 only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "authz_decision_response", must,
 | 
			
		||||
                                             origdoc)
 | 
			
		||||
                                             origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False,
 | 
			
		||||
                                                 origdoc=None):
 | 
			
		||||
                                                 origdoc=None,
 | 
			
		||||
                                                 only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "name_id_mapping_request",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False,
 | 
			
		||||
                                                  origdoc=None):
 | 
			
		||||
                                                  origdoc=None,
 | 
			
		||||
                                                  only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "name_id_mapping_response",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_artifact_request(self, decoded_xml, must=False,
 | 
			
		||||
                                          origdoc=None):
 | 
			
		||||
                                          origdoc=None, only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "artifact_request",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_artifact_response(self, decoded_xml, must=False,
 | 
			
		||||
                                           origdoc=None):
 | 
			
		||||
                                           origdoc=None, only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "artifact_response",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_manage_name_id_request(self, decoded_xml, must=False,
 | 
			
		||||
                                                origdoc=None):
 | 
			
		||||
                                                origdoc=None,
 | 
			
		||||
                                                only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "manage_name_id_request",
 | 
			
		||||
                                             must, origdoc)
 | 
			
		||||
                                             must, origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_manage_name_id_response(self, decoded_xml, must=False,
 | 
			
		||||
                                                 origdoc=None):
 | 
			
		||||
                                                 origdoc=None,
 | 
			
		||||
                                                 only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "manage_name_id_response", must,
 | 
			
		||||
                                             origdoc)
 | 
			
		||||
                                             origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_assertion_id_request(self, decoded_xml, must=False,
 | 
			
		||||
                                              origdoc=None):
 | 
			
		||||
                                              origdoc=None,
 | 
			
		||||
                                              only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml,
 | 
			
		||||
                                             "assertion_id_request", must,
 | 
			
		||||
                                             origdoc)
 | 
			
		||||
                                             origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_assertion_id_response(self, decoded_xml, must=False,
 | 
			
		||||
                                               origdoc=None):
 | 
			
		||||
                                               origdoc=None,
 | 
			
		||||
                                               only_valid_cert=False):
 | 
			
		||||
        return self.correctly_signed_message(decoded_xml, "assertion", must,
 | 
			
		||||
                                             origdoc)
 | 
			
		||||
                                             origdoc, only_valid_cert)
 | 
			
		||||
 | 
			
		||||
    def correctly_signed_response(self, decoded_xml, must=False, origdoc=None):
 | 
			
		||||
        """ Check if a instance is correctly signed, if we have metadata for
 | 
			
		||||
@@ -1357,7 +1394,8 @@ class SecurityContext(object):
 | 
			
		||||
            # Try to find the signing cert in the assertion
 | 
			
		||||
            for assertion in (response.assertion or response.encrypted_assertion):
 | 
			
		||||
                if response.encrypted_assertion:
 | 
			
		||||
                    decoded_xml = self.decrypt(assertion.encrypted_data.to_string())
 | 
			
		||||
                    decoded_xml = self.decrypt(
 | 
			
		||||
                        assertion.encrypted_data.to_string())
 | 
			
		||||
                    assertion = saml.assertion_from_string(decoded_xml)
 | 
			
		||||
 | 
			
		||||
                if not assertion.signature:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user