Cleaned up a bit

This commit is contained in:
Roland Hedberg
2014-03-06 19:43:06 +01:00
parent eb12108fb9
commit 74d87d417e

View File

@@ -52,8 +52,16 @@ from saml2.time_util import str_to_time
from tempfile import NamedTemporaryFile
from subprocess import Popen, PIPE
from xmlenc import EncryptionMethod, EncryptedKey, CipherData, CipherValue, \
EncryptedData
from xmlenc import EncryptionMethod
from xmlenc import EncryptedKey
from xmlenc import CipherData
from xmlenc import CipherValue
from xmlenc import EncryptedData
from Crypto.Hash import SHA256
from Crypto.Hash import SHA384
from Crypto.Hash import SHA512
from Crypto.Hash import SHA
logger = logging.getLogger(__name__)
@@ -63,7 +71,6 @@ RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5"
TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc"
from Crypto.Hash import SHA256, SHA384, SHA512, SHA
class SigverError(SAMLError):
@@ -925,12 +932,14 @@ def security_context(conf, debug=None):
raise SigverError('Unknown crypto_backend %s' % (
repr(conf.crypto_backend)))
return SecurityContext(crypto, conf.key_file,
cert_file=conf.cert_file, metadata=metadata,
debug=debug, only_use_keys_in_metadata=_only_md,
cert_handler_extra_class=conf.cert_handler_extra_class,
generate_cert_info=conf.generate_cert_info, tmp_cert_file=conf.tmp_cert_file,
tmp_key_file=conf.tmp_key_file, validate_certificate=conf.validate_certificate)
return SecurityContext(
crypto, conf.key_file, cert_file=conf.cert_file, metadata=metadata,
debug=debug, only_use_keys_in_metadata=_only_md,
cert_handler_extra_class=conf.cert_handler_extra_class,
generate_cert_info=conf.generate_cert_info,
tmp_cert_file=conf.tmp_cert_file,
tmp_key_file=conf.tmp_key_file,
validate_certificate=conf.validate_certificate)
class CertHandlerExtra(object):
@@ -940,7 +949,8 @@ class CertHandlerExtra(object):
def use_generate_cert_func(self):
raise Exception("use_generate_cert_func function must be implemented")
def generate_cert(self, generate_cert_info, root_cert_string, root_key_string):
def generate_cert(self, generate_cert_info, root_cert_string,
root_key_string):
raise Exception("generate_cert function must be implemented")
#Excepts to return (cert_string, key_string)
@@ -953,12 +963,14 @@ class CertHandlerExtra(object):
class CertHandler(object):
def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=None, key_type="pem",
generate_cert_info=None, cert_handler_extra_class=None, tmp_cert_file=None, tmp_key_file=None,
verify_cert=False):
def __init__(self, security_context, cert_file=None, cert_type="pem",
key_file=None, key_type="pem", generate_cert_info=None,
cert_handler_extra_class=None, tmp_cert_file=None,
tmp_key_file=None, verify_cert=False):
"""
Initiates the class for handling certificates. Enables the certificates to either be a single certificate
as base functionality or makes it possible to generate a new certificate for each call to the function.
Initiates the class for handling certificates. Enables the certificates
to either be a single certificate as base functionality or makes it
possible to generate a new certificate for each call to the function.
:param key_file:
:param key_type:
:param cert_file:
@@ -968,7 +980,9 @@ class CertHandler(object):
"""
self._verify_cert = False
self._generate_cert = False
self._last_cert_verified = None #This cert do not have to be valid, it is just the last cert to be validated.
#This cert do not have to be valid, it is just the last cert to be
# validated.
self._last_cert_verified = None
if cert_type == "pem" and key_type == "pem":
self._verify_cert = verify_cert is True
self._security_context = security_context
@@ -978,7 +992,8 @@ class CertHandler(object):
else:
self._key_str = ""
if cert_file is not None:
self._cert_str = self._osw.read_str_from_file(cert_file, cert_type)
self._cert_str = self._osw.read_str_from_file(cert_file,
cert_type)
else:
self._cert_str = ""
@@ -989,8 +1004,9 @@ class CertHandler(object):
self._cert_info = None
self._generate_cert_func_active = False
if generate_cert_info is not None and len(self._cert_str) > 0 and len(self._key_str) > 0 \
and tmp_key_file is not None and tmp_cert_file is not None:
if generate_cert_info is not None and len(self._cert_str) > 0 and \
len(self._key_str) > 0 and tmp_key_file is not \
None and tmp_cert_file is not None:
self._generate_cert = True
self._cert_info = generate_cert_info
self._cert_handler_extra_class = cert_handler_extra_class
@@ -999,8 +1015,10 @@ class CertHandler(object):
if self._verify_cert:
cert_str = self._osw.read_str_from_file(cert_file, "pem")
self._last_validated_cert = cert_str
if self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_validate_cert_func():
self._cert_handler_extra_class.validate_cert(cert_str, self._cert_str, self._key_str)
if self._cert_handler_extra_class is not None and \
self._cert_handler_extra_class.use_validate_cert_func():
self._cert_handler_extra_class.validate_cert(
cert_str, self._cert_str, self._key_str)
else:
valid, mess = self._osw.verify(self._cert_str, cert_str)
logger.info("CertHandler.verify_cert: %s" % mess)
@@ -1016,22 +1034,25 @@ class CertHandler(object):
self._tmp_cert_str = client_crt
#No private key for signing
self._tmp_key_str = ""
elif self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_generate_cert_func():
elif self._cert_handler_extra_class is not None and \
self._cert_handler_extra_class.use_generate_cert_func():
(self._tmp_cert_str, self._tmp_key_str) = \
self._cert_handler_extra_class.generate_cert(self._cert_info, self._cert_str, self._key_str)
else:
self._tmp_cert_str, self._tmp_key_str = self._osw.create_certificate(self._cert_info, request=True)
self._tmp_cert_str = self._osw.create_cert_signed_certificate(self._cert_str, self._key_str,
self._tmp_cert_str)
valid, mess = self._osw.verify(self._cert_str, self._tmp_cert_str)
self._tmp_cert_str = self._osw.create_cert_signed_certificate(
self._cert_str, self._key_str, self._tmp_cert_str)
valid, mess = self._osw.verify(self._cert_str,
self._tmp_cert_str)
self._osw.write_str_to_file(self._tmp_cert_file, self._tmp_cert_str)
self._osw.write_str_to_file(self._tmp_key_file, self._tmp_key_str)
self._security_context.key_file = self._tmp_key_file
self._security_context.cert_file = self._tmp_cert_file
self._security_context.key_type = "pem"
self._security_context.cert_type = "pem"
self._security_context.my_cert = read_cert_from_file(self._security_context.cert_file,
self._security_context.cert_type)
self._security_context.my_cert = read_cert_from_file(
self._security_context.cert_file,
self._security_context.cert_type)
# How to get a rsa pub key fingerprint from a certificate
@@ -1043,8 +1064,9 @@ class SecurityContext(object):
def __init__(self, crypto, key_file="", key_type="pem",
cert_file="", cert_type="pem", metadata=None,
debug=False, template="", encrypt_key_type="des-192",
only_use_keys_in_metadata=False, cert_handler_extra_class=None, generate_cert_info=None,
tmp_cert_file=None, tmp_key_file=None, validate_certificate=None):
only_use_keys_in_metadata=False, cert_handler_extra_class=None,
generate_cert_info=None, tmp_cert_file=None,
tmp_key_file=None, validate_certificate=None):
self.crypto = crypto
assert (isinstance(self.crypto, CryptoBackend))
@@ -1059,8 +1081,10 @@ class SecurityContext(object):
self.my_cert = read_cert_from_file(cert_file, cert_type)
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, key_type, generate_cert_info,
cert_handler_extra_class, tmp_cert_file, tmp_key_file, validate_certificate)
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file,
key_type, generate_cert_info,
cert_handler_extra_class, tmp_cert_file,
tmp_key_file, validate_certificate)
self.cert_handler.update_cert(True)
@@ -1135,7 +1159,8 @@ class SecurityContext(object):
)
def _check_signature(self, decoded_xml, item, node_name=NODE_NAME,
origdoc=None, id_attr="", must=False, only_valid_cert=False):
origdoc=None, id_attr="", must=False,
only_valid_cert=False):
#print item
try:
issuer = item.issuer.text.strip()
@@ -1179,13 +1204,15 @@ class SecurityContext(object):
try:
if self.verify_signature(origdoc, pem_file,
node_name=node_name,
node_id=item.id, id_attr=id_attr):
node_id=item.id,
id_attr=id_attr):
verified = True
break
except Exception:
if self.verify_signature(decoded_xml, pem_file,
node_name=node_name,
node_id=item.id, id_attr=id_attr):
node_id=item.id,
id_attr=id_attr):
verified = True
break
else:
@@ -1247,91 +1274,101 @@ class SecurityContext(object):
return msg
return self._check_signature(decoded_xml, msg, class_name(msg),
origdoc, must=must, only_valid_cert=only_valid_cert)
origdoc, must=must,
only_valid_cert=only_valid_cert)
def correctly_signed_authn_request(self, decoded_xml, must=False,
origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "authn_request",
must, origdoc, only_valid_cert=only_valid_cert)
must, origdoc,
only_valid_cert=only_valid_cert)
def correctly_signed_authn_query(self, decoded_xml, must=False,
origdoc=None):
origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "authn_query",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_logout_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "logout_request",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_logout_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "logout_response",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_attribute_query(self, decoded_xml, must=False,
origdoc=None):
origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "attribute_query",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_authz_decision_query(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"authz_decision_query", must,
origdoc)
origdoc, only_valid_cert)
def correctly_signed_authz_decision_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"authz_decision_response", must,
origdoc)
origdoc, only_valid_cert)
def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"name_id_mapping_request",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"name_id_mapping_response",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_artifact_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"artifact_request",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_artifact_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None, only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"artifact_response",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_manage_name_id_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"manage_name_id_request",
must, origdoc)
must, origdoc, only_valid_cert)
def correctly_signed_manage_name_id_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"manage_name_id_response", must,
origdoc)
origdoc, only_valid_cert)
def correctly_signed_assertion_id_request(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml,
"assertion_id_request", must,
origdoc)
origdoc, only_valid_cert)
def correctly_signed_assertion_id_response(self, decoded_xml, must=False,
origdoc=None):
origdoc=None,
only_valid_cert=False):
return self.correctly_signed_message(decoded_xml, "assertion", must,
origdoc)
origdoc, only_valid_cert)
def correctly_signed_response(self, decoded_xml, must=False, origdoc=None):
""" Check if a instance is correctly signed, if we have metadata for
@@ -1353,11 +1390,12 @@ class SecurityContext(object):
origdoc)
if isinstance(response, Response) and (response.assertion or
response.encrypted_assertion):
response.encrypted_assertion):
# Try to find the signing cert in the assertion
for assertion in (response.assertion or response.encrypted_assertion):
if response.encrypted_assertion:
decoded_xml = self.decrypt(assertion.encrypted_data.to_string())
decoded_xml = self.decrypt(
assertion.encrypted_data.to_string())
assertion = saml.assertion_from_string(decoded_xml)
if not assertion.signature: