diff --git a/src/saml2/extension/pefim.py b/src/saml2/extension/pefim.py index 2e05668..f4810a4 100644 --- a/src/saml2/extension/pefim.py +++ b/src/saml2/extension/pefim.py @@ -3,6 +3,7 @@ import saml2 from saml2 import SamlBase from xmldsig import X509Data +from xmldsig import KeyInfo NAMESPACE = 'urn:net:eustix:names:tc:PEFIM:0.0:assertion' @@ -16,11 +17,16 @@ class SPCertEncType_(SamlBase): c_attributes = SamlBase.c_attributes.copy() c_child_order = SamlBase.c_child_order[:] c_cardinality = SamlBase.c_cardinality.copy() - c_children['{http://www.w3.org/2000/09/xmldsig#}X509Data'] = ('x509_data', - [X509Data]) + c_children['{http://www.w3.org/2000/09/xmldsig#}KeyInfo'] = ('key_info', + [KeyInfo]) + c_cardinality['key_info'] = {"min": 1} + c_attributes['VerifyDepth'] = ('verify_depth', 'unsignedByte', False) + c_child_order.extend(['key_info']) def __init__(self, + key_info=None, x509_data=None, + verify_depth='1', text=None, extension_elements=None, extension_attributes=None): @@ -28,7 +34,14 @@ class SPCertEncType_(SamlBase): text=text, extension_elements=extension_elements, extension_attributes=extension_attributes) - self.x509_data = x509_data + if key_info: + self.key_info = key_info + elif x509_data: + self.key_info = KeyInfo(x509_data=x509_data) + else: + self.key_info = [] + self.verify_depth = verify_depth + #self.x509_data = x509_data def spcertenc_type__from_string(xml_string): diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index efe96fe..be36246 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -21,6 +21,7 @@ from Crypto.Util.asn1 import DerSequence from Crypto.PublicKey import RSA from saml2.cert import OpenSSLWrapper from saml2.extension import pefim +from saml2.extension.pefim import SPCertEnc from saml2.saml import EncryptedAssertion import xmldsig as ds @@ -1063,19 +1064,24 @@ def encrypt_cert_from_item(item): try: _elem = extension_elements_to_elements(item.extension_elements[0].children, [pefim, ds]) - if len(_elem) == 1: - _encrypt_cert = _elem[0].x509_data[0].x509_certificate.text - else: - certs = cert_from_instance(item) - if len(certs) > 0: - _encrypt_cert = certs[0] - except Exception: + for _tmp_elem in _elem: + if isinstance(_tmp_elem, SPCertEnc): + for _tmp_key_info in _tmp_elem.key_info: + if _tmp_key_info.x509_data is not None and len(_tmp_key_info.x509_data) > 0: + _encrypt_cert = _tmp_key_info.x509_data[0].x509_certificate.text + break + #_encrypt_cert = _elem[0].x509_data[0].x509_certificate.text +# else: +# certs = cert_from_instance(item) +# if len(certs) > 0: +# _encrypt_cert = certs[0] + except Exception as _exception: pass - if _encrypt_cert is None: - certs = cert_from_instance(item) - if len(certs) > 0: - _encrypt_cert = certs[0] +# if _encrypt_cert is None: +# certs = cert_from_instance(item) +# if len(certs) > 0: +# _encrypt_cert = certs[0] if _encrypt_cert is not None: if _encrypt_cert.find("-----BEGIN CERTIFICATE-----\n") == -1: diff --git a/tests/test_82_pefim.py b/tests/test_82_pefim.py index cc82551..f40608c 100644 --- a/tests/test_82_pefim.py +++ b/tests/test_82_pefim.py @@ -48,5 +48,5 @@ _elem = extension_elements_to_elements(parsed.extensions.extension_elements, assert len(_elem) == 1 _spcertenc = _elem[0] -_cert = _spcertenc.x509_data[0].x509_certificate.text +_cert = _spcertenc.key_info[0].x509_data[0].x509_certificate.text assert cert == _cert