From a12cc2a979f5203e4c136684aa35741f4a73d1ce Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 4 Oct 2015 10:15:13 -0400 Subject: [PATCH] Fixed various errors. --- src/saml2/client.py | 7 +++++- src/saml2/metadata.py | 53 +++++++++++++++++++++++------------------ src/saml2/response.py | 4 +++- src/saml2/sigver.py | 50 ++++++++++++++++++++++---------------- tests/test_40_sigver.py | 19 ++++++++------- tests/test_51_client.py | 45 ++++++++++++++++++++++++++++++++++ 6 files changed, 125 insertions(+), 53 deletions(-) diff --git a/src/saml2/client.py b/src/saml2/client.py index 5478c47..64871d6 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -206,9 +206,14 @@ class Saml2Client(Base): destination = destinations(srvs)[0] logger.info("destination to provider: %s" % destination) + try: + session_info = self.users.get_info_from(name_id, entity_id) + session_indexes = [session_info['session_index']] + except KeyError: + session_indexes = None req_id, request = self.create_logout_request( destination, entity_id, name_id=name_id, reason=reason, - expire=expire) + expire=expire, session_indexes=session_indexes) # to_sign = [] if binding.startswith("http://"): diff --git a/src/saml2/metadata.py b/src/saml2/metadata.py index 52d8d58..eb33788 100644 --- a/src/saml2/metadata.py +++ b/src/saml2/metadata.py @@ -59,16 +59,17 @@ bMDNS = b'"urn:oasis:names:tc:SAML:2.0:metadata"' XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\"" bXMLNSXS = b" xmlns:xs=\"http://www.w3.org/2001/XMLSchema\"" + def metadata_tostring_fix(desc, nspair, xmlstring=""): if not xmlstring: xmlstring = desc.to_string(nspair) if six.PY2: if "\"xs:string\"" in xmlstring and XMLNSXS not in xmlstring: - xmlstring = xmlstring.replace(MDNS, MDNS+XMLNSXS) + xmlstring = xmlstring.replace(MDNS, MDNS + XMLNSXS) else: if b"\"xs:string\"" in xmlstring and bXMLNSXS not in xmlstring: - xmlstring = xmlstring.replace(bMDNS, bMDNS+bXMLNSXS) + xmlstring = xmlstring.replace(bMDNS, bMDNS + bXMLNSXS) return xmlstring @@ -77,7 +78,7 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None, keyfile=None, mid=None, name=None, sign=None): valid_for = 0 nspair = {"xs": "http://www.w3.org/2001/XMLSchema"} - #paths = [".", "/opt/local/bin"] + # paths = [".", "/opt/local/bin"] if valid: valid_for = int(valid) # Hours @@ -97,11 +98,8 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None, secc = security_context(conf) if mid: - desc = entities_descriptor(eds, valid_for, name, mid, - sign, secc) - valid_instance(desc) - - return metadata_tostring_fix(desc, nspair) + eid, xmldoc = entities_descriptor(eds, valid_for, name, mid, + sign, secc) else: eid = eds[0] if sign: @@ -109,9 +107,8 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None, else: xmldoc = None - valid_instance(eid) - xmldoc = metadata_tostring_fix(eid, nspair, xmldoc) - return xmldoc + valid_instance(eid) + return metadata_tostring_fix(eid, nspair, xmldoc) def _localized_name(val, klass): @@ -346,6 +343,7 @@ def do_idpdisc(discovery_response): return idpdisc.DiscoveryResponse(index="0", location=discovery_response, binding=idpdisc.NAMESPACE) + ENDPOINTS = { "sp": { "artifact_resolution_service": (md.ArtifactResolutionService, True), @@ -425,7 +423,8 @@ def do_endpoints(conf, endpoints): servs = [] i = 1 for args in conf[endpoint]: - if isinstance(args, six.string_types): # Assume it's the location + if isinstance(args, + six.string_types): # Assume it's the location args = {"location": args, "binding": DEFAULT_BINDING[endpoint]} elif isinstance(args, tuple) or isinstance(args, list): @@ -453,16 +452,16 @@ def do_endpoints(conf, endpoints): pass return service + DEFAULT = { "want_assertions_signed": "true", "authn_requests_signed": "false", "want_authn_requests_signed": "false", - #"want_authn_requests_only_with_valid_cert": "false", + # "want_authn_requests_only_with_valid_cert": "false", } def do_attribute_consuming_service(conf, spsso): - service_description = service_name = None requested_attributes = [] acs = conf.attribute_converters @@ -557,7 +556,8 @@ def do_spsso_descriptor(conf, cert=None, enc_cert=None): if cert or enc_cert: metadata_key_usage = conf.metadata_key_usage - spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert, use=metadata_key_usage) + spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert, + use=metadata_key_usage) for key in ["want_assertions_signed", "authn_requests_signed"]: try: @@ -605,10 +605,11 @@ def do_idpsso_descriptor(conf, cert=None, enc_cert=None): idpsso.extensions.add_extension_element(do_uiinfo(ui_info)) if cert or enc_cert: - idpsso.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage) + idpsso.key_descriptor = do_key_descriptor(cert, enc_cert, + use=conf.metadata_key_usage) for key in ["want_authn_requests_signed"]: - #"want_authn_requests_only_with_valid_cert"]: + # "want_authn_requests_only_with_valid_cert"]: try: val = conf.getattr(key, "idp") if val is None: @@ -635,7 +636,8 @@ def do_aa_descriptor(conf, cert=None, enc_cert=None): _do_nameid_format(aad, conf, "aa") if cert or enc_cert: - aad.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage) + aad.key_descriptor = do_key_descriptor(cert, enc_cert, + use=conf.metadata_key_usage) attributes = conf.getattr("attribute", "aa") if attributes: @@ -664,7 +666,8 @@ def do_aq_descriptor(conf, cert=None, enc_cert=None): _do_nameid_format(aqs, conf, "aq") if cert or enc_cert: - aqs.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage) + aqs.key_descriptor = do_key_descriptor(cert, enc_cert, + use=conf.metadata_key_usage) return aqs @@ -685,7 +688,8 @@ def do_pdp_descriptor(conf, cert=None, enc_cert=None): _do_nameid_format(pdp, conf, "pdp") if cert: - pdp.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage) + pdp.key_descriptor = do_key_descriptor(cert, enc_cert, + use=conf.metadata_key_usage) return pdp @@ -702,7 +706,8 @@ def entity_descriptor(confd): if confd.encryption_keypairs is not None: enc_cert = [] for _encryption in confd.encryption_keypairs: - enc_cert.append("".join(open(_encryption["cert_file"]).readlines()[1:-1])) + enc_cert.append( + "".join(open(_encryption["cert_file"]).readlines()[1:-1])) entd = md.EntityDescriptor() entd.entity_id = confd.entityid @@ -736,13 +741,15 @@ def entity_descriptor(confd): entd.idpsso_descriptor = do_idpsso_descriptor(confd, mycert, enc_cert) if "aa" in serves: confd.context = "aa" - entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert, enc_cert) + entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert, + enc_cert) if "pdp" in serves: confd.context = "pdp" entd.pdp_descriptor = do_pdp_descriptor(confd, mycert, enc_cert) if "aq" in serves: confd.context = "aq" - entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert, enc_cert) + entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert, + enc_cert) return entd diff --git a/src/saml2/response.py b/src/saml2/response.py index 3195f23..be9f1fa 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -1036,9 +1036,11 @@ class AuthnResponse(StatusResponse): "issuer": self.issuer(), "not_on_or_after": nooa, "authz_decision_info": self.authz_decision_info()} else: + authn_statement = self.assertion.authn_statement[0] return {"ava": self.ava, "name_id": self.name_id, "came_from": self.came_from, "issuer": self.issuer(), - "not_on_or_after": nooa, "authn_info": self.authn_info()} + "not_on_or_after": nooa, "authn_info": self.authn_info(), + "session_index": authn_statement.session_index} def __str__(self): if not isinstance(self.xmlstr, six.string_types): diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 8c4301e..c8d2daa 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -353,7 +353,10 @@ def make_temp(string, suffix="", decode=True, delete=True): xmlsec function). """ ntf = NamedTemporaryFile(suffix=suffix, delete=delete) - assert isinstance(string, six.binary_type) + # Python3 tempfile requires byte-like object + if not isinstance(string, six.binary_type): + string = string.encode("utf8") + if decode: ntf.write(base64.b64decode(string)) else: @@ -657,6 +660,12 @@ LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" +def make_str(txt): + if isinstance(txt, six.string_types): + return txt + else: + return txt.decode("utf8") + # --------------------------------------------------------------------------- @@ -674,29 +683,32 @@ def read_cert_from_file(cert_file, cert_type): return "" if cert_type == "pem": - line = open(cert_file).read().replace("\r\n", "\n").split("\n") + _a = read_file(cert_file, 'rb').decode("utf8") + _b = _a.replace("\r\n", "\n") + lines = _b.split("\n") - if line[0] == "-----BEGIN CERTIFICATE-----": - line = line[1:] - elif line[0] == "-----BEGIN PUBLIC KEY-----": - line = line[1:] + for pattern in ("-----BEGIN CERTIFICATE-----", + "-----BEGIN PUBLIC KEY-----"): + if pattern in lines: + lines = lines[lines.index(pattern)+1:] + break else: raise CertificateError("Strange beginning of PEM file") - while line[-1] == "": - line = line[:-1] - - if line[-1] == "-----END CERTIFICATE-----": - line = line[:-1] - elif line[-1] == "-----END PUBLIC KEY-----": - line = line[:-1] + for pattern in ("-----END CERTIFICATE-----", + "-----END PUBLIC KEY-----"): + if pattern in lines: + lines = lines[:lines.index(pattern)] + break else: raise CertificateError("Strange end of PEM file") - return "".join(line) + return make_str("".join(lines).encode("utf8")) + if cert_type in ["der", "cer", "crt"]: - data = read_file(cert_file) - return base64.b64encode(str(data)) + data = read_file(cert_file, 'rb') + _cert = base64.b64encode(data) + return make_str(_cert) class CryptoBackend(): @@ -850,8 +862,8 @@ class CryptoBackendXmlSec1(CryptoBackend): 'id','Id' or 'ID' :return: The signed statement """ - if not isinstance(statement, six.binary_type): - statement = str(statement).encode('utf-8') + if isinstance(statement, SamlBase): + statement = str(statement) _, fil = make_temp(statement, suffix=".xml", decode=False, delete=self._xmlsec_delete_tmpfiles) @@ -1284,8 +1296,6 @@ class SecurityContext(object): self.encryption_keypairs = encryption_keypairs self.enc_cert_type = enc_cert_type - - self.my_cert = read_cert_from_file(cert_file, cert_type) self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, diff --git a/tests/test_40_sigver.py b/tests/test_40_sigver.py index 5950c11..00e8479 100644 --- a/tests/test_40_sigver.py +++ b/tests/test_40_sigver.py @@ -1,17 +1,20 @@ #!/usr/bin/env python import base64 -from saml2.sigver import pre_encryption_part, make_temp, XmlsecError, \ - SigverError -from saml2.mdstore import MetadataStore -from saml2.saml import assertion_from_string, EncryptedAssertion -from saml2.samlp import response_from_string - -from saml2 import sigver, extension_elements_to_elements +from saml2 import sigver +from saml2 import extension_elements_to_elements from saml2 import class_name from saml2 import time_util from saml2 import saml, samlp from saml2 import config +from saml2.sigver import pre_encryption_part +from saml2.sigver import make_temp +from saml2.sigver import XmlsecError +from saml2.sigver import SigverError +from saml2.mdstore import MetadataStore +from saml2.saml import assertion_from_string +from saml2.saml import EncryptedAssertion +from saml2.samlp import response_from_string from saml2.s_utils import factory, do_attribute_statement from py.test import raises @@ -510,6 +513,6 @@ def test_xmlsec_err(): if __name__ == "__main__": t = TestSecurity() t.setup_class() - t.test_verify_1() + t.test_sign_assertion() #test_xmlsec_err() diff --git a/tests/test_51_client.py b/tests/test_51_client.py index 2233ee6..71e1733 100644 --- a/tests/test_51_client.py +++ b/tests/test_51_client.py @@ -25,6 +25,7 @@ from saml2.response import LogoutResponse from saml2.saml import NAMEID_FORMAT_PERSISTENT, EncryptedAssertion, Advice from saml2.saml import NAMEID_FORMAT_TRANSIENT from saml2.saml import NameID +from saml2.samlp import SessionIndex from saml2.server import Server from saml2.sigver import pre_encryption_part, make_temp, pre_encrypt_assertion from saml2.sigver import rm_xmltag @@ -319,6 +320,19 @@ class TestClient: except Exception: # missing certificate self.client.sec.verify_signature(ar_str, node_name=class_name(ar)) + def test_create_logout_request(self): + req_id, req = self.client.create_logout_request( + "http://localhost:8088/slo", "urn:mace:example.com:saml:roland:idp", + name_id=nid, reason="Tired", expire=in_a_while(minutes=15), + session_indexes=["_foo"]) + + assert req.destination == "http://localhost:8088/slo" + assert req.reason == "Tired" + assert req.version == "2.0" + assert req.name_id == nid + assert req.issuer.text == "urn:mace:example.com:saml:roland:sp" + assert req.session_index == [SessionIndex("_foo")] + def test_response_1(self): IDP = "urn:mace:example.com:saml:roland:idp" @@ -359,6 +373,7 @@ class TestClient: assert session_info["came_from"] == "http://foo.example.com/service" response = samlp.response_from_string(authn_response.xmlstr) assert response.destination == "http://lingon.catalogix.se:8087/" + assert "session_index" in session_info # One person in the cache assert len(self.client.users.subjects()) == 1 @@ -1220,6 +1235,36 @@ class TestClient: BINDING_HTTP_REDIRECT) print(res) + def test_do_logout_post(self): + # information about the user from an IdP + session_info = { + "name_id": nid, + "issuer": "urn:mace:example.com:saml:roland:idp", + "not_on_or_after": in_a_while(minutes=15), + "ava": { + "givenName": "Anders", + "surName": "Andersson", + "mail": "anders.andersson@example.com" + }, + "session_index": SessionIndex("_foo") + } + self.client.users.add_information_about_person(session_info) + entity_ids = self.client.users.issuers_of_info(nid) + assert entity_ids == ["urn:mace:example.com:saml:roland:idp"] + resp = self.client.do_logout(nid, entity_ids, "Tired", + in_a_while(minutes=5), sign=True, + expected_binding=BINDING_HTTP_POST) + assert resp + assert len(resp) == 1 + assert list(resp.keys()) == entity_ids + binding, info = resp[entity_ids[0]] + assert binding == BINDING_HTTP_POST + + _dic = unpack_form(info["data"][3]) + res = self.server.parse_logout_request(_dic["SAMLRequest"], + BINDING_HTTP_POST) + assert b'_foo' in res.xmlstr + # Below can only be done with dummy Server IDP = "urn:mace:example.com:saml:roland:idp"