diff --git a/src/saml2/entity.py b/src/saml2/entity.py index 26c2fe5..92db817 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -36,7 +36,7 @@ from saml2.s_utils import rndstr from saml2.s_utils import success_status_factory from saml2.s_utils import decode_base64_and_inflate from saml2.s_utils import UnsupportedBinding -from saml2.samlp import AuthnRequest, SessionIndex +from saml2.samlp import AuthnRequest, SessionIndex, response_from_string from saml2.samlp import AuthzDecisionQuery from saml2.samlp import AuthnQuery from saml2.samlp import AssertionIDRequest @@ -504,7 +504,7 @@ class Entity(HTTPBase): def _response(self, in_response_to, consumer_url=None, status=None, issuer=None, sign=False, to_sign=None, encrypt_assertion=False, encrypt_assertion_self_contained=False, encrypted_advice_attributes=False, - encrypt_cert=None, **kwargs): + encrypt_cert=None,sign_assertion=None, **kwargs): """ Create a Response. Encryption: encrypt_assertion must be true for encryption to be performed. If encrypted_advice_attributes also is @@ -541,13 +541,14 @@ class Entity(HTTPBase): if not sign and to_sign and not encrypt_assertion: return signed_instance_factory(response, self.sec, to_sign) - if encrypt_assertion: - node_xpath = None + if encrypt_assertion or (encrypted_advice_attributes and response.assertion.advice is not None and + len(response.assertion.advice.assertion) == 1): if sign: response.signature = pre_signature_part(response.id, self.sec.my_cert, 1) sign_class = [(class_name(response), response.id)] cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) + encrypt_advice = False if encrypted_advice_attributes and response.assertion.advice is not None \ and len(response.assertion.advice.assertion) == 1: tmp_assertion = response.assertion.advice.assertion[0] @@ -558,26 +559,59 @@ class Entity(HTTPBase): else: response.assertion.advice.encrypted_assertion[0].add_extension_element(tmp_assertion) response.assertion.advice.assertion = [] + to_sign_advice = [] + if sign_assertion is not None and sign_assertion: + if response.assertion.advice and response.assertion.advice.assertion: + for tmp_assertion in response.assertion.advice.assertion: + tmp_assertion.signature = pre_signature_part(tmp_assertion.id, self.sec.my_cert, 1) + to_sign_advice.append((class_name(tmp_assertion), tmp_assertion.id)) if encrypt_assertion_self_contained: advice_tag = response.assertion.advice._to_element_tree().tag assertion_tag = tmp_assertion._to_element_tree().tag - response = response.get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion( - assertion_tag, advice_tag) + response = response.\ + get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion(assertion_tag, + advice_tag) node_xpath = ''.join(["/*[local-name()=\"%s\"]" % v for v in ["Response", "Assertion", "Advice", "EncryptedAssertion", "Assertion"]]) - elif encrypt_assertion_self_contained: - assertion_tag = response.assertion._to_element_tree().tag - response = pre_encrypt_assertion(response) - response = response.get_xml_string_with_self_contained_assertion_within_encrypted_assertion( - assertion_tag) - else: - response = pre_encrypt_assertion(response) - if to_sign: - response = signed_instance_factory(response, self.sec, to_sign) - _, cert_file = make_temp("%s" % encrypt_cert, decode=False) - response = cbxs.encrypt_assertion(response, cert_file, - pre_encryption_part(), node_xpath=node_xpath) - # template(response.assertion.id)) + + if to_sign_advice: + response = signed_instance_factory(response, self.sec, to_sign_advice) + _, cert_file = make_temp("%s" % encrypt_cert, decode=False) + response = cbxs.encrypt_assertion(response, cert_file, + pre_encryption_part(), node_xpath=node_xpath) + encrypt_advice = True + if encrypt_assertion: + response = response_from_string(response) + if encrypt_assertion: + if encrypt_assertion_self_contained: + assertion_tag = None + try: + assertion_tag = response.assertion._to_element_tree().tag + except: + assertion_tag = response.assertion[0]._to_element_tree().tag + response = pre_encrypt_assertion(response) + response = response.get_xml_string_with_self_contained_assertion_within_encrypted_assertion( + assertion_tag) + else: + response = pre_encrypt_assertion(response) + to_sign_assertion = [] + if sign_assertion is not None and sign_assertion: + response.assertion.signature = pre_signature_part(response.assertion.id, self.sec.my_cert, 1) + to_sign_assertion.append((class_name(response.assertion), response.assertion.id)) + if to_sign_assertion: + response = signed_instance_factory(response, self.sec, to_sign_assertion) + if encrypt_cert is not None and not encrypt_advice: + _, cert_file = make_temp("%s" % encrypt_cert, decode=False) + else: + tmp_cert_str = "%s" % self.sec.my_cert + if "-----BEGIN CERTIFICATE-----" not in tmp_cert_str: + tmp_cert_str = "-----BEGIN CERTIFICATE-----\n" + tmp_cert_str + if "-----END CERTIFICATE-----" not in tmp_cert_str: + tmp_cert_str = tmp_cert_str + "\n-----END CERTIFICATE-----\n" + _, cert_file = make_temp(tmp_cert_str, decode=False) + response = cbxs.encrypt_assertion(response, cert_file, + pre_encryption_part()) + # template(response.assertion.id)) if sign: return signed_instance_factory(response, self.sec, sign_class) else: diff --git a/src/saml2/server.py b/src/saml2/server.py index c6d3169..f876e28 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -357,7 +357,7 @@ class Server(Entity): # tmp_authn_statement = authn_statement # authn_statement = None - if encrypt_assertion and encrypted_advice_attributes: + if encrypted_advice_attributes: assertion_attributes = self.setup_assertion(None, sp_entity_id, None, None, None, policy, None, None, identity, best_effort, sign_response, False) assertion = self.setup_assertion(authn, sp_entity_id, in_response_to, consumer_url, @@ -374,15 +374,15 @@ class Server(Entity): sign_response) to_sign = [] - if sign_assertion is not None and sign_assertion: - if assertion.advice and assertion.advice.assertion: - for tmp_assertion in assertion.advice.assertion: - tmp_assertion.signature = pre_signature_part(tmp_assertion.id, self.sec.my_cert, 1) - to_sign.append((class_name(tmp_assertion), tmp_assertion.id)) - assertion.signature = pre_signature_part(assertion.id, - self.sec.my_cert, 1) + #if sign_assertion is not None and sign_assertion: + # if assertion.advice and assertion.advice.assertion: + # for tmp_assertion in assertion.advice.assertion: + # tmp_assertion.signature = pre_signature_part(tmp_assertion.id, self.sec.my_cert, 1) + # to_sign.append((class_name(tmp_assertion), tmp_assertion.id)) + # assertion.signature = pre_signature_part(assertion.id, + # self.sec.my_cert, 1) # Just the assertion or the response and the assertion ? - to_sign.append((class_name(assertion), assertion.id)) + # to_sign.append((class_name(assertion), assertion.id)) # Store which assertion that has been sent to which SP about which @@ -401,7 +401,8 @@ class Server(Entity): sign_response, to_sign, encrypt_assertion=encrypt_assertion, encrypt_cert=encrypt_cert, encrypt_assertion_self_contained=encrypt_assertion_self_contained, - encrypted_advice_attributes=encrypted_advice_attributes, **args) + encrypted_advice_attributes=encrypted_advice_attributes,sign_assertion=sign_assertion, + **args) # ------------------------------------------------------------------------