From 74d87d417ebd7927f153a526bac3c37d140702e9 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 6 Mar 2014 19:43:06 +0100 Subject: [PATCH] Cleaned up a bit --- src/saml2/sigver.py | 170 +++++++++++++++++++++++++++----------------- 1 file changed, 104 insertions(+), 66 deletions(-) diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 193bf6f..287e51e 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -52,8 +52,16 @@ from saml2.time_util import str_to_time from tempfile import NamedTemporaryFile from subprocess import Popen, PIPE -from xmlenc import EncryptionMethod, EncryptedKey, CipherData, CipherValue, \ - EncryptedData +from xmlenc import EncryptionMethod +from xmlenc import EncryptedKey +from xmlenc import CipherData +from xmlenc import CipherValue +from xmlenc import EncryptedData + +from Crypto.Hash import SHA256 +from Crypto.Hash import SHA384 +from Crypto.Hash import SHA512 +from Crypto.Hash import SHA logger = logging.getLogger(__name__) @@ -63,7 +71,6 @@ RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1" RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5" TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc" -from Crypto.Hash import SHA256, SHA384, SHA512, SHA class SigverError(SAMLError): @@ -925,12 +932,14 @@ def security_context(conf, debug=None): raise SigverError('Unknown crypto_backend %s' % ( repr(conf.crypto_backend))) - return SecurityContext(crypto, conf.key_file, - cert_file=conf.cert_file, metadata=metadata, - debug=debug, only_use_keys_in_metadata=_only_md, - cert_handler_extra_class=conf.cert_handler_extra_class, - generate_cert_info=conf.generate_cert_info, tmp_cert_file=conf.tmp_cert_file, - tmp_key_file=conf.tmp_key_file, validate_certificate=conf.validate_certificate) + return SecurityContext( + crypto, conf.key_file, cert_file=conf.cert_file, metadata=metadata, + debug=debug, only_use_keys_in_metadata=_only_md, + cert_handler_extra_class=conf.cert_handler_extra_class, + generate_cert_info=conf.generate_cert_info, + tmp_cert_file=conf.tmp_cert_file, + tmp_key_file=conf.tmp_key_file, + validate_certificate=conf.validate_certificate) class CertHandlerExtra(object): @@ -940,7 +949,8 @@ class CertHandlerExtra(object): def use_generate_cert_func(self): raise Exception("use_generate_cert_func function must be implemented") - def generate_cert(self, generate_cert_info, root_cert_string, root_key_string): + def generate_cert(self, generate_cert_info, root_cert_string, + root_key_string): raise Exception("generate_cert function must be implemented") #Excepts to return (cert_string, key_string) @@ -953,12 +963,14 @@ class CertHandlerExtra(object): class CertHandler(object): - def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=None, key_type="pem", - generate_cert_info=None, cert_handler_extra_class=None, tmp_cert_file=None, tmp_key_file=None, - verify_cert=False): + def __init__(self, security_context, cert_file=None, cert_type="pem", + key_file=None, key_type="pem", generate_cert_info=None, + cert_handler_extra_class=None, tmp_cert_file=None, + tmp_key_file=None, verify_cert=False): """ - Initiates the class for handling certificates. Enables the certificates to either be a single certificate - as base functionality or makes it possible to generate a new certificate for each call to the function. + Initiates the class for handling certificates. Enables the certificates + to either be a single certificate as base functionality or makes it + possible to generate a new certificate for each call to the function. :param key_file: :param key_type: :param cert_file: @@ -968,7 +980,9 @@ class CertHandler(object): """ self._verify_cert = False self._generate_cert = False - self._last_cert_verified = None #This cert do not have to be valid, it is just the last cert to be validated. + #This cert do not have to be valid, it is just the last cert to be + # validated. + self._last_cert_verified = None if cert_type == "pem" and key_type == "pem": self._verify_cert = verify_cert is True self._security_context = security_context @@ -978,7 +992,8 @@ class CertHandler(object): else: self._key_str = "" if cert_file is not None: - self._cert_str = self._osw.read_str_from_file(cert_file, cert_type) + self._cert_str = self._osw.read_str_from_file(cert_file, + cert_type) else: self._cert_str = "" @@ -989,8 +1004,9 @@ class CertHandler(object): self._cert_info = None self._generate_cert_func_active = False - if generate_cert_info is not None and len(self._cert_str) > 0 and len(self._key_str) > 0 \ - and tmp_key_file is not None and tmp_cert_file is not None: + if generate_cert_info is not None and len(self._cert_str) > 0 and \ + len(self._key_str) > 0 and tmp_key_file is not \ + None and tmp_cert_file is not None: self._generate_cert = True self._cert_info = generate_cert_info self._cert_handler_extra_class = cert_handler_extra_class @@ -999,8 +1015,10 @@ class CertHandler(object): if self._verify_cert: cert_str = self._osw.read_str_from_file(cert_file, "pem") self._last_validated_cert = cert_str - if self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_validate_cert_func(): - self._cert_handler_extra_class.validate_cert(cert_str, self._cert_str, self._key_str) + if self._cert_handler_extra_class is not None and \ + self._cert_handler_extra_class.use_validate_cert_func(): + self._cert_handler_extra_class.validate_cert( + cert_str, self._cert_str, self._key_str) else: valid, mess = self._osw.verify(self._cert_str, cert_str) logger.info("CertHandler.verify_cert: %s" % mess) @@ -1016,22 +1034,25 @@ class CertHandler(object): self._tmp_cert_str = client_crt #No private key for signing self._tmp_key_str = "" - elif self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_generate_cert_func(): + elif self._cert_handler_extra_class is not None and \ + self._cert_handler_extra_class.use_generate_cert_func(): (self._tmp_cert_str, self._tmp_key_str) = \ self._cert_handler_extra_class.generate_cert(self._cert_info, self._cert_str, self._key_str) else: self._tmp_cert_str, self._tmp_key_str = self._osw.create_certificate(self._cert_info, request=True) - self._tmp_cert_str = self._osw.create_cert_signed_certificate(self._cert_str, self._key_str, - self._tmp_cert_str) - valid, mess = self._osw.verify(self._cert_str, self._tmp_cert_str) + self._tmp_cert_str = self._osw.create_cert_signed_certificate( + self._cert_str, self._key_str, self._tmp_cert_str) + valid, mess = self._osw.verify(self._cert_str, + self._tmp_cert_str) self._osw.write_str_to_file(self._tmp_cert_file, self._tmp_cert_str) self._osw.write_str_to_file(self._tmp_key_file, self._tmp_key_str) self._security_context.key_file = self._tmp_key_file self._security_context.cert_file = self._tmp_cert_file self._security_context.key_type = "pem" self._security_context.cert_type = "pem" - self._security_context.my_cert = read_cert_from_file(self._security_context.cert_file, - self._security_context.cert_type) + self._security_context.my_cert = read_cert_from_file( + self._security_context.cert_file, + self._security_context.cert_type) # How to get a rsa pub key fingerprint from a certificate @@ -1043,8 +1064,9 @@ class SecurityContext(object): def __init__(self, crypto, key_file="", key_type="pem", cert_file="", cert_type="pem", metadata=None, debug=False, template="", encrypt_key_type="des-192", - only_use_keys_in_metadata=False, cert_handler_extra_class=None, generate_cert_info=None, - tmp_cert_file=None, tmp_key_file=None, validate_certificate=None): + only_use_keys_in_metadata=False, cert_handler_extra_class=None, + generate_cert_info=None, tmp_cert_file=None, + tmp_key_file=None, validate_certificate=None): self.crypto = crypto assert (isinstance(self.crypto, CryptoBackend)) @@ -1059,8 +1081,10 @@ class SecurityContext(object): self.my_cert = read_cert_from_file(cert_file, cert_type) - self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, key_type, generate_cert_info, - cert_handler_extra_class, tmp_cert_file, tmp_key_file, validate_certificate) + self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, + key_type, generate_cert_info, + cert_handler_extra_class, tmp_cert_file, + tmp_key_file, validate_certificate) self.cert_handler.update_cert(True) @@ -1135,7 +1159,8 @@ class SecurityContext(object): ) def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, - origdoc=None, id_attr="", must=False, only_valid_cert=False): + origdoc=None, id_attr="", must=False, + only_valid_cert=False): #print item try: issuer = item.issuer.text.strip() @@ -1179,13 +1204,15 @@ class SecurityContext(object): try: if self.verify_signature(origdoc, pem_file, node_name=node_name, - node_id=item.id, id_attr=id_attr): + node_id=item.id, + id_attr=id_attr): verified = True break except Exception: if self.verify_signature(decoded_xml, pem_file, node_name=node_name, - node_id=item.id, id_attr=id_attr): + node_id=item.id, + id_attr=id_attr): verified = True break else: @@ -1247,91 +1274,101 @@ class SecurityContext(object): return msg return self._check_signature(decoded_xml, msg, class_name(msg), - origdoc, must=must, only_valid_cert=only_valid_cert) + origdoc, must=must, + only_valid_cert=only_valid_cert) def correctly_signed_authn_request(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "authn_request", - must, origdoc, only_valid_cert=only_valid_cert) + must, origdoc, + only_valid_cert=only_valid_cert) def correctly_signed_authn_query(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "authn_query", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_logout_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "logout_request", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_logout_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "logout_response", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_attribute_query(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "attribute_query", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_authz_decision_query(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "authz_decision_query", must, - origdoc) + origdoc, only_valid_cert) def correctly_signed_authz_decision_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "authz_decision_response", must, - origdoc) + origdoc, only_valid_cert) def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "name_id_mapping_request", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "name_id_mapping_response", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_artifact_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "artifact_request", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_artifact_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "artifact_response", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_manage_name_id_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "manage_name_id_request", - must, origdoc) + must, origdoc, only_valid_cert) def correctly_signed_manage_name_id_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "manage_name_id_response", must, - origdoc) + origdoc, only_valid_cert) def correctly_signed_assertion_id_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "assertion_id_request", must, - origdoc) + origdoc, only_valid_cert) def correctly_signed_assertion_id_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None, + only_valid_cert=False): return self.correctly_signed_message(decoded_xml, "assertion", must, - origdoc) + origdoc, only_valid_cert) def correctly_signed_response(self, decoded_xml, must=False, origdoc=None): """ Check if a instance is correctly signed, if we have metadata for @@ -1353,11 +1390,12 @@ class SecurityContext(object): origdoc) if isinstance(response, Response) and (response.assertion or - response.encrypted_assertion): + response.encrypted_assertion): # Try to find the signing cert in the assertion for assertion in (response.assertion or response.encrypted_assertion): if response.encrypted_assertion: - decoded_xml = self.decrypt(assertion.encrypted_data.to_string()) + decoded_xml = self.decrypt( + assertion.encrypted_data.to_string()) assertion = saml.assertion_from_string(decoded_xml) if not assertion.signature: