Added PEFIM changes that had been removed.

This commit is contained in:
Hans Hörberg
2015-05-25 17:38:18 +02:00
parent 93bdb86a81
commit d6d76e1efb
4 changed files with 53 additions and 24 deletions

View File

@@ -2,7 +2,7 @@
import saml2 import saml2
from saml2 import SamlBase from saml2 import SamlBase
from saml2.xmldsig import X509Data from saml2.xmldsig import KeyInfo
NAMESPACE = 'urn:net:eustix:names:tc:PEFIM:0.0:assertion' NAMESPACE = 'urn:net:eustix:names:tc:PEFIM:0.0:assertion'
@@ -16,11 +16,16 @@ class SPCertEncType_(SamlBase):
c_attributes = SamlBase.c_attributes.copy() c_attributes = SamlBase.c_attributes.copy()
c_child_order = SamlBase.c_child_order[:] c_child_order = SamlBase.c_child_order[:]
c_cardinality = SamlBase.c_cardinality.copy() c_cardinality = SamlBase.c_cardinality.copy()
c_children['{http://www.w3.org/2000/09/xmldsig#}X509Data'] = ('x509_data', c_children['{http://www.w3.org/2000/09/xmldsig#}KeyInfo'] = ('key_info',
[X509Data]) [KeyInfo])
c_cardinality['key_info'] = {"min": 1}
c_attributes['VerifyDepth'] = ('verify_depth', 'unsignedByte', False)
c_child_order.extend(['key_info'])
def __init__(self, def __init__(self,
key_info=None,
x509_data=None, x509_data=None,
verify_depth='1',
text=None, text=None,
extension_elements=None, extension_elements=None,
extension_attributes=None): extension_attributes=None):
@@ -28,7 +33,14 @@ class SPCertEncType_(SamlBase):
text=text, text=text,
extension_elements=extension_elements, extension_elements=extension_elements,
extension_attributes=extension_attributes) extension_attributes=extension_attributes)
self.x509_data = x509_data if key_info:
self.key_info = key_info
elif x509_data:
self.key_info = KeyInfo(x509_data=x509_data)
else:
self.key_info = []
self.verify_depth = verify_depth
#self.x509_data = x509_data
def spcertenc_type__from_string(xml_string): def spcertenc_type__from_string(xml_string):
@@ -63,4 +75,3 @@ ELEMENT_BY_TAG = {
def factory(tag, **kwargs): def factory(tag, **kwargs):
return ELEMENT_BY_TAG[tag](**kwargs) return ELEMENT_BY_TAG[tag](**kwargs)

View File

@@ -871,11 +871,16 @@ class AuthnResponse(StatusResponse):
logger.debug("***Encrypted assertion/-s***") logger.debug("***Encrypted assertion/-s***")
decr_text = "%s" % self.response decr_text = "%s" % self.response
resp = self.response resp = self.response
while self.find_encrypt_data(resp): decr_text_old = None
while self.find_encrypt_data(resp) and decr_text_old != decr_text:
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys) decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text) resp = samlp.response_from_string(decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text) _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text)
while self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions): decr_text_old = None
while self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions) and \
decr_text_old != decr_text:
decr_text_old = decr_text
decr_text = self.sec.decrypt_keys(decr_text, keys) decr_text = self.sec.decrypt_keys(decr_text, keys)
resp = samlp.response_from_string(decr_text) resp = samlp.response_from_string(decr_text)
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True) _enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
@@ -893,7 +898,8 @@ class AuthnResponse(StatusResponse):
tmp_ass.advice.assertion.extend(advice_res) tmp_ass.advice.assertion.extend(advice_res)
else: else:
tmp_ass.advice.assertion = advice_res tmp_ass.advice.assertion = advice_res
tmp_ass.advice.encrypted_assertion = [] if len(advice_res) > 0:
tmp_ass.advice.encrypted_assertion = []
self.response.assertion = resp.assertion self.response.assertion = resp.assertion
for assertion in _enc_assertions: for assertion in _enc_assertions:
if not self._assertion(assertion, True): if not self._assertion(assertion, True):
@@ -902,7 +908,8 @@ class AuthnResponse(StatusResponse):
self.assertions.append(assertion) self.assertions.append(assertion)
self.xmlstr = decr_text self.xmlstr = decr_text
self.response.encrypted_assertion = [] if len(_enc_assertions) > 0:
self.response.encrypted_assertion = []
if self.response.assertion: if self.response.assertion:
for assertion in self.response.assertion: for assertion in self.response.assertion:

View File

@@ -41,6 +41,7 @@ from saml2 import VERSION
from saml2.cert import OpenSSLWrapper from saml2.cert import OpenSSLWrapper
from saml2.extension import pefim from saml2.extension import pefim
from saml2.extension.pefim import SPCertEnc
from saml2.saml import EncryptedAssertion from saml2.saml import EncryptedAssertion
import saml2.xmldsig as ds import saml2.xmldsig as ds
@@ -1066,21 +1067,30 @@ def security_context(conf, debug=None):
def encrypt_cert_from_item(item): def encrypt_cert_from_item(item):
_encrypt_cert = None _encrypt_cert = None
try: try:
_elem = extension_elements_to_elements(item.extension_elements[0].children, try:
[pefim, ds]) _elem = extension_elements_to_elements(item.extensions.extension_elements,[pefim, ds])
if len(_elem) == 1: except:
_encrypt_cert = _elem[0].x509_data[0].x509_certificate.text _elem = extension_elements_to_elements(item.extension_elements[0].children,
#else: [pefim, ds])
# certs = cert_from_instance(item)
# if len(certs) > 0: for _tmp_elem in _elem:
# _encrypt_cert = certs[0] if isinstance(_tmp_elem, SPCertEnc):
except Exception: for _tmp_key_info in _tmp_elem.key_info:
if _tmp_key_info.x509_data is not None and len(_tmp_key_info.x509_data) > 0:
_encrypt_cert = _tmp_key_info.x509_data[0].x509_certificate.text
break
#_encrypt_cert = _elem[0].x509_data[0].x509_certificate.text
# else:
# certs = cert_from_instance(item)
# if len(certs) > 0:
# _encrypt_cert = certs[0]
except Exception as _exception:
pass pass
#if _encrypt_cert is None: # if _encrypt_cert is None:
# certs = cert_from_instance(item) # certs = cert_from_instance(item)
# if len(certs) > 0: # if len(certs) > 0:
# _encrypt_cert = certs[0] # _encrypt_cert = certs[0]
if _encrypt_cert is not None: if _encrypt_cert is not None:
if _encrypt_cert.find("-----BEGIN CERTIFICATE-----\n") == -1: if _encrypt_cert.find("-----BEGIN CERTIFICATE-----\n") == -1:
@@ -1090,6 +1100,7 @@ def encrypt_cert_from_item(item):
return _encrypt_cert return _encrypt_cert
class CertHandlerExtra(object): class CertHandlerExtra(object):
def __init__(self): def __init__(self):
pass pass

View File

@@ -48,5 +48,5 @@ _elem = extension_elements_to_elements(parsed.extensions.extension_elements,
assert len(_elem) == 1 assert len(_elem) == 1
_spcertenc = _elem[0] _spcertenc = _elem[0]
_cert = _spcertenc.x509_data[0].x509_certificate.text _cert = _spcertenc.key_info[0].x509_data[0].x509_certificate.text
assert cert == _cert assert cert == _cert