Cleaned up a bit

This commit is contained in:
Roland Hedberg
2014-03-06 19:43:06 +01:00
parent eb12108fb9
commit 74d87d417e

View File

@@ -52,8 +52,16 @@ from saml2.time_util import str_to_time
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from subprocess import Popen, PIPE from subprocess import Popen, PIPE
from xmlenc import EncryptionMethod, EncryptedKey, CipherData, CipherValue, \ from xmlenc import EncryptionMethod
EncryptedData from xmlenc import EncryptedKey
from xmlenc import CipherData
from xmlenc import CipherValue
from xmlenc import EncryptedData
from Crypto.Hash import SHA256
from Crypto.Hash import SHA384
from Crypto.Hash import SHA512
from Crypto.Hash import SHA
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -63,7 +71,6 @@ RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5" RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5"
TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc" TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc"
from Crypto.Hash import SHA256, SHA384, SHA512, SHA
class SigverError(SAMLError): class SigverError(SAMLError):
@@ -925,12 +932,14 @@ def security_context(conf, debug=None):
raise SigverError('Unknown crypto_backend %s' % ( raise SigverError('Unknown crypto_backend %s' % (
repr(conf.crypto_backend))) repr(conf.crypto_backend)))
return SecurityContext(crypto, conf.key_file, return SecurityContext(
cert_file=conf.cert_file, metadata=metadata, crypto, conf.key_file, cert_file=conf.cert_file, metadata=metadata,
debug=debug, only_use_keys_in_metadata=_only_md, debug=debug, only_use_keys_in_metadata=_only_md,
cert_handler_extra_class=conf.cert_handler_extra_class, cert_handler_extra_class=conf.cert_handler_extra_class,
generate_cert_info=conf.generate_cert_info, tmp_cert_file=conf.tmp_cert_file, generate_cert_info=conf.generate_cert_info,
tmp_key_file=conf.tmp_key_file, validate_certificate=conf.validate_certificate) tmp_cert_file=conf.tmp_cert_file,
tmp_key_file=conf.tmp_key_file,
validate_certificate=conf.validate_certificate)
class CertHandlerExtra(object): class CertHandlerExtra(object):
@@ -940,7 +949,8 @@ class CertHandlerExtra(object):
def use_generate_cert_func(self): def use_generate_cert_func(self):
raise Exception("use_generate_cert_func function must be implemented") raise Exception("use_generate_cert_func function must be implemented")
def generate_cert(self, generate_cert_info, root_cert_string, root_key_string): def generate_cert(self, generate_cert_info, root_cert_string,
root_key_string):
raise Exception("generate_cert function must be implemented") raise Exception("generate_cert function must be implemented")
#Excepts to return (cert_string, key_string) #Excepts to return (cert_string, key_string)
@@ -953,12 +963,14 @@ class CertHandlerExtra(object):
class CertHandler(object): class CertHandler(object):
def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=None, key_type="pem", def __init__(self, security_context, cert_file=None, cert_type="pem",
generate_cert_info=None, cert_handler_extra_class=None, tmp_cert_file=None, tmp_key_file=None, key_file=None, key_type="pem", generate_cert_info=None,
verify_cert=False): cert_handler_extra_class=None, tmp_cert_file=None,
tmp_key_file=None, verify_cert=False):
""" """
Initiates the class for handling certificates. Enables the certificates to either be a single certificate Initiates the class for handling certificates. Enables the certificates
as base functionality or makes it possible to generate a new certificate for each call to the function. to either be a single certificate as base functionality or makes it
possible to generate a new certificate for each call to the function.
:param key_file: :param key_file:
:param key_type: :param key_type:
:param cert_file: :param cert_file:
@@ -968,7 +980,9 @@ class CertHandler(object):
""" """
self._verify_cert = False self._verify_cert = False
self._generate_cert = False self._generate_cert = False
self._last_cert_verified = None #This cert do not have to be valid, it is just the last cert to be validated. #This cert do not have to be valid, it is just the last cert to be
# validated.
self._last_cert_verified = None
if cert_type == "pem" and key_type == "pem": if cert_type == "pem" and key_type == "pem":
self._verify_cert = verify_cert is True self._verify_cert = verify_cert is True
self._security_context = security_context self._security_context = security_context
@@ -978,7 +992,8 @@ class CertHandler(object):
else: else:
self._key_str = "" self._key_str = ""
if cert_file is not None: if cert_file is not None:
self._cert_str = self._osw.read_str_from_file(cert_file, cert_type) self._cert_str = self._osw.read_str_from_file(cert_file,
cert_type)
else: else:
self._cert_str = "" self._cert_str = ""
@@ -989,8 +1004,9 @@ class CertHandler(object):
self._cert_info = None self._cert_info = None
self._generate_cert_func_active = False self._generate_cert_func_active = False
if generate_cert_info is not None and len(self._cert_str) > 0 and len(self._key_str) > 0 \ if generate_cert_info is not None and len(self._cert_str) > 0 and \
and tmp_key_file is not None and tmp_cert_file is not None: len(self._key_str) > 0 and tmp_key_file is not \
None and tmp_cert_file is not None:
self._generate_cert = True self._generate_cert = True
self._cert_info = generate_cert_info self._cert_info = generate_cert_info
self._cert_handler_extra_class = cert_handler_extra_class self._cert_handler_extra_class = cert_handler_extra_class
@@ -999,8 +1015,10 @@ class CertHandler(object):
if self._verify_cert: if self._verify_cert:
cert_str = self._osw.read_str_from_file(cert_file, "pem") cert_str = self._osw.read_str_from_file(cert_file, "pem")
self._last_validated_cert = cert_str self._last_validated_cert = cert_str
if self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_validate_cert_func(): if self._cert_handler_extra_class is not None and \
self._cert_handler_extra_class.validate_cert(cert_str, self._cert_str, self._key_str) self._cert_handler_extra_class.use_validate_cert_func():
self._cert_handler_extra_class.validate_cert(
cert_str, self._cert_str, self._key_str)
else: else:
valid, mess = self._osw.verify(self._cert_str, cert_str) valid, mess = self._osw.verify(self._cert_str, cert_str)
logger.info("CertHandler.verify_cert: %s" % mess) logger.info("CertHandler.verify_cert: %s" % mess)
@@ -1016,22 +1034,25 @@ class CertHandler(object):
self._tmp_cert_str = client_crt self._tmp_cert_str = client_crt
#No private key for signing #No private key for signing
self._tmp_key_str = "" self._tmp_key_str = ""
elif self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_generate_cert_func(): elif self._cert_handler_extra_class is not None and \
self._cert_handler_extra_class.use_generate_cert_func():
(self._tmp_cert_str, self._tmp_key_str) = \ (self._tmp_cert_str, self._tmp_key_str) = \
self._cert_handler_extra_class.generate_cert(self._cert_info, self._cert_str, self._key_str) self._cert_handler_extra_class.generate_cert(self._cert_info, self._cert_str, self._key_str)
else: else:
self._tmp_cert_str, self._tmp_key_str = self._osw.create_certificate(self._cert_info, request=True) self._tmp_cert_str, self._tmp_key_str = self._osw.create_certificate(self._cert_info, request=True)
self._tmp_cert_str = self._osw.create_cert_signed_certificate(self._cert_str, self._key_str, self._tmp_cert_str = self._osw.create_cert_signed_certificate(
self._tmp_cert_str) self._cert_str, self._key_str, self._tmp_cert_str)
valid, mess = self._osw.verify(self._cert_str, self._tmp_cert_str) valid, mess = self._osw.verify(self._cert_str,
self._tmp_cert_str)
self._osw.write_str_to_file(self._tmp_cert_file, self._tmp_cert_str) self._osw.write_str_to_file(self._tmp_cert_file, self._tmp_cert_str)
self._osw.write_str_to_file(self._tmp_key_file, self._tmp_key_str) self._osw.write_str_to_file(self._tmp_key_file, self._tmp_key_str)
self._security_context.key_file = self._tmp_key_file self._security_context.key_file = self._tmp_key_file
self._security_context.cert_file = self._tmp_cert_file self._security_context.cert_file = self._tmp_cert_file
self._security_context.key_type = "pem" self._security_context.key_type = "pem"
self._security_context.cert_type = "pem" self._security_context.cert_type = "pem"
self._security_context.my_cert = read_cert_from_file(self._security_context.cert_file, self._security_context.my_cert = read_cert_from_file(
self._security_context.cert_type) self._security_context.cert_file,
self._security_context.cert_type)
# How to get a rsa pub key fingerprint from a certificate # How to get a rsa pub key fingerprint from a certificate
@@ -1043,8 +1064,9 @@ class SecurityContext(object):
def __init__(self, crypto, key_file="", key_type="pem", def __init__(self, crypto, key_file="", key_type="pem",
cert_file="", cert_type="pem", metadata=None, cert_file="", cert_type="pem", metadata=None,
debug=False, template="", encrypt_key_type="des-192", debug=False, template="", encrypt_key_type="des-192",
only_use_keys_in_metadata=False, cert_handler_extra_class=None, generate_cert_info=None, only_use_keys_in_metadata=False, cert_handler_extra_class=None,
tmp_cert_file=None, tmp_key_file=None, validate_certificate=None): generate_cert_info=None, tmp_cert_file=None,
tmp_key_file=None, validate_certificate=None):
self.crypto = crypto self.crypto = crypto
assert (isinstance(self.crypto, CryptoBackend)) assert (isinstance(self.crypto, CryptoBackend))
@@ -1059,8 +1081,10 @@ class SecurityContext(object):
self.my_cert = read_cert_from_file(cert_file, cert_type) self.my_cert = read_cert_from_file(cert_file, cert_type)
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, key_type, generate_cert_info, self.cert_handler = CertHandler(self, cert_file, cert_type, key_file,
cert_handler_extra_class, tmp_cert_file, tmp_key_file, validate_certificate) key_type, generate_cert_info,
cert_handler_extra_class, tmp_cert_file,
tmp_key_file, validate_certificate)
self.cert_handler.update_cert(True) self.cert_handler.update_cert(True)
@@ -1135,7 +1159,8 @@ class SecurityContext(object):
) )
def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, def _check_signature(self, decoded_xml, item, node_name=NODE_NAME,
origdoc=None, id_attr="", must=False, only_valid_cert=False): origdoc=None, id_attr="", must=False,
only_valid_cert=False):
#print item #print item
try: try:
issuer = item.issuer.text.strip() issuer = item.issuer.text.strip()
@@ -1179,13 +1204,15 @@ class SecurityContext(object):
try: try:
if self.verify_signature(origdoc, pem_file, if self.verify_signature(origdoc, pem_file,
node_name=node_name, node_name=node_name,
node_id=item.id, id_attr=id_attr): node_id=item.id,
id_attr=id_attr):
verified = True verified = True
break break
except Exception: except Exception:
if self.verify_signature(decoded_xml, pem_file, if self.verify_signature(decoded_xml, pem_file,
node_name=node_name, node_name=node_name,
node_id=item.id, id_attr=id_attr): node_id=item.id,
id_attr=id_attr):
verified = True verified = True
break break
else: else:
@@ -1247,91 +1274,101 @@ class SecurityContext(object):
return msg return msg
return self._check_signature(decoded_xml, msg, class_name(msg), return self._check_signature(decoded_xml, msg, class_name(msg),
origdoc, must=must, only_valid_cert=only_valid_cert) origdoc, must=must,
only_valid_cert=only_valid_cert)
def correctly_signed_authn_request(self, decoded_xml, must=False, def correctly_signed_authn_request(self, decoded_xml, must=False,
origdoc=None, only_valid_cert=False): origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "authn_request", return self.correctly_signed_message(decoded_xml, "authn_request",
must, origdoc, only_valid_cert=only_valid_cert) must, origdoc,
only_valid_cert=only_valid_cert)
def correctly_signed_authn_query(self, decoded_xml, must=False, def correctly_signed_authn_query(self, decoded_xml, must=False,
origdoc=None): origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "authn_query", return self.correctly_signed_message(decoded_xml, "authn_query",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_logout_request(self, decoded_xml, must=False, def correctly_signed_logout_request(self, decoded_xml, must=False,
origdoc=None): origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "logout_request", return self.correctly_signed_message(decoded_xml, "logout_request",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_logout_response(self, decoded_xml, must=False, def correctly_signed_logout_response(self, decoded_xml, must=False,
origdoc=None): origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "logout_response", return self.correctly_signed_message(decoded_xml, "logout_response",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_attribute_query(self, decoded_xml, must=False, def correctly_signed_attribute_query(self, decoded_xml, must=False,
origdoc=None): origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "attribute_query", return self.correctly_signed_message(decoded_xml, "attribute_query",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_authz_decision_query(self, decoded_xml, must=False, def correctly_signed_authz_decision_query(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"authz_decision_query", must, "authz_decision_query", must,
origdoc) origdoc, only_valid_cert)
def correctly_signed_authz_decision_response(self, decoded_xml, must=False, def correctly_signed_authz_decision_response(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"authz_decision_response", must, "authz_decision_response", must,
origdoc) origdoc, only_valid_cert)
def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False, def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"name_id_mapping_request", "name_id_mapping_request",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False, def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"name_id_mapping_response", "name_id_mapping_response",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_artifact_request(self, decoded_xml, must=False, def correctly_signed_artifact_request(self, decoded_xml, must=False,
origdoc=None): origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"artifact_request", "artifact_request",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_artifact_response(self, decoded_xml, must=False, def correctly_signed_artifact_response(self, decoded_xml, must=False,
origdoc=None): origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"artifact_response", "artifact_response",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_manage_name_id_request(self, decoded_xml, must=False, def correctly_signed_manage_name_id_request(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"manage_name_id_request", "manage_name_id_request",
must, origdoc) must, origdoc, only_valid_cert)
def correctly_signed_manage_name_id_response(self, decoded_xml, must=False, def correctly_signed_manage_name_id_response(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"manage_name_id_response", must, "manage_name_id_response", must,
origdoc) origdoc, only_valid_cert)
def correctly_signed_assertion_id_request(self, decoded_xml, must=False, def correctly_signed_assertion_id_request(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, return self.correctly_signed_message(decoded_xml,
"assertion_id_request", must, "assertion_id_request", must,
origdoc) origdoc, only_valid_cert)
def correctly_signed_assertion_id_response(self, decoded_xml, must=False, def correctly_signed_assertion_id_response(self, decoded_xml, must=False,
origdoc=None): origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "assertion", must, return self.correctly_signed_message(decoded_xml, "assertion", must,
origdoc) origdoc, only_valid_cert)
def correctly_signed_response(self, decoded_xml, must=False, origdoc=None): def correctly_signed_response(self, decoded_xml, must=False, origdoc=None):
""" Check if a instance is correctly signed, if we have metadata for """ Check if a instance is correctly signed, if we have metadata for
@@ -1353,11 +1390,12 @@ class SecurityContext(object):
origdoc) origdoc)
if isinstance(response, Response) and (response.assertion or if isinstance(response, Response) and (response.assertion or
response.encrypted_assertion): response.encrypted_assertion):
# Try to find the signing cert in the assertion # Try to find the signing cert in the assertion
for assertion in (response.assertion or response.encrypted_assertion): for assertion in (response.assertion or response.encrypted_assertion):
if response.encrypted_assertion: if response.encrypted_assertion:
decoded_xml = self.decrypt(assertion.encrypted_data.to_string()) decoded_xml = self.decrypt(
assertion.encrypted_data.to_string())
assertion = saml.assertion_from_string(decoded_xml) assertion = saml.assertion_from_string(decoded_xml)
if not assertion.signature: if not assertion.signature: