Fixed various errors.
This commit is contained in:
@@ -206,9 +206,14 @@ class Saml2Client(Base):
|
||||
|
||||
destination = destinations(srvs)[0]
|
||||
logger.info("destination to provider: %s" % destination)
|
||||
try:
|
||||
session_info = self.users.get_info_from(name_id, entity_id)
|
||||
session_indexes = [session_info['session_index']]
|
||||
except KeyError:
|
||||
session_indexes = None
|
||||
req_id, request = self.create_logout_request(
|
||||
destination, entity_id, name_id=name_id, reason=reason,
|
||||
expire=expire)
|
||||
expire=expire, session_indexes=session_indexes)
|
||||
|
||||
# to_sign = []
|
||||
if binding.startswith("http://"):
|
||||
|
@@ -59,16 +59,17 @@ bMDNS = b'"urn:oasis:names:tc:SAML:2.0:metadata"'
|
||||
XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
|
||||
bXMLNSXS = b" xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
|
||||
|
||||
|
||||
def metadata_tostring_fix(desc, nspair, xmlstring=""):
|
||||
if not xmlstring:
|
||||
xmlstring = desc.to_string(nspair)
|
||||
|
||||
if six.PY2:
|
||||
if "\"xs:string\"" in xmlstring and XMLNSXS not in xmlstring:
|
||||
xmlstring = xmlstring.replace(MDNS, MDNS+XMLNSXS)
|
||||
xmlstring = xmlstring.replace(MDNS, MDNS + XMLNSXS)
|
||||
else:
|
||||
if b"\"xs:string\"" in xmlstring and bXMLNSXS not in xmlstring:
|
||||
xmlstring = xmlstring.replace(bMDNS, bMDNS+bXMLNSXS)
|
||||
xmlstring = xmlstring.replace(bMDNS, bMDNS + bXMLNSXS)
|
||||
|
||||
return xmlstring
|
||||
|
||||
@@ -77,7 +78,7 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
|
||||
keyfile=None, mid=None, name=None, sign=None):
|
||||
valid_for = 0
|
||||
nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
|
||||
#paths = [".", "/opt/local/bin"]
|
||||
# paths = [".", "/opt/local/bin"]
|
||||
|
||||
if valid:
|
||||
valid_for = int(valid) # Hours
|
||||
@@ -97,11 +98,8 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
|
||||
secc = security_context(conf)
|
||||
|
||||
if mid:
|
||||
desc = entities_descriptor(eds, valid_for, name, mid,
|
||||
sign, secc)
|
||||
valid_instance(desc)
|
||||
|
||||
return metadata_tostring_fix(desc, nspair)
|
||||
eid, xmldoc = entities_descriptor(eds, valid_for, name, mid,
|
||||
sign, secc)
|
||||
else:
|
||||
eid = eds[0]
|
||||
if sign:
|
||||
@@ -109,9 +107,8 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
|
||||
else:
|
||||
xmldoc = None
|
||||
|
||||
valid_instance(eid)
|
||||
xmldoc = metadata_tostring_fix(eid, nspair, xmldoc)
|
||||
return xmldoc
|
||||
valid_instance(eid)
|
||||
return metadata_tostring_fix(eid, nspair, xmldoc)
|
||||
|
||||
|
||||
def _localized_name(val, klass):
|
||||
@@ -346,6 +343,7 @@ def do_idpdisc(discovery_response):
|
||||
return idpdisc.DiscoveryResponse(index="0", location=discovery_response,
|
||||
binding=idpdisc.NAMESPACE)
|
||||
|
||||
|
||||
ENDPOINTS = {
|
||||
"sp": {
|
||||
"artifact_resolution_service": (md.ArtifactResolutionService, True),
|
||||
@@ -425,7 +423,8 @@ def do_endpoints(conf, endpoints):
|
||||
servs = []
|
||||
i = 1
|
||||
for args in conf[endpoint]:
|
||||
if isinstance(args, six.string_types): # Assume it's the location
|
||||
if isinstance(args,
|
||||
six.string_types): # Assume it's the location
|
||||
args = {"location": args,
|
||||
"binding": DEFAULT_BINDING[endpoint]}
|
||||
elif isinstance(args, tuple) or isinstance(args, list):
|
||||
@@ -453,16 +452,16 @@ def do_endpoints(conf, endpoints):
|
||||
pass
|
||||
return service
|
||||
|
||||
|
||||
DEFAULT = {
|
||||
"want_assertions_signed": "true",
|
||||
"authn_requests_signed": "false",
|
||||
"want_authn_requests_signed": "false",
|
||||
#"want_authn_requests_only_with_valid_cert": "false",
|
||||
# "want_authn_requests_only_with_valid_cert": "false",
|
||||
}
|
||||
|
||||
|
||||
def do_attribute_consuming_service(conf, spsso):
|
||||
|
||||
service_description = service_name = None
|
||||
requested_attributes = []
|
||||
acs = conf.attribute_converters
|
||||
@@ -557,7 +556,8 @@ def do_spsso_descriptor(conf, cert=None, enc_cert=None):
|
||||
|
||||
if cert or enc_cert:
|
||||
metadata_key_usage = conf.metadata_key_usage
|
||||
spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert, use=metadata_key_usage)
|
||||
spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert,
|
||||
use=metadata_key_usage)
|
||||
|
||||
for key in ["want_assertions_signed", "authn_requests_signed"]:
|
||||
try:
|
||||
@@ -605,10 +605,11 @@ def do_idpsso_descriptor(conf, cert=None, enc_cert=None):
|
||||
idpsso.extensions.add_extension_element(do_uiinfo(ui_info))
|
||||
|
||||
if cert or enc_cert:
|
||||
idpsso.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
|
||||
idpsso.key_descriptor = do_key_descriptor(cert, enc_cert,
|
||||
use=conf.metadata_key_usage)
|
||||
|
||||
for key in ["want_authn_requests_signed"]:
|
||||
#"want_authn_requests_only_with_valid_cert"]:
|
||||
# "want_authn_requests_only_with_valid_cert"]:
|
||||
try:
|
||||
val = conf.getattr(key, "idp")
|
||||
if val is None:
|
||||
@@ -635,7 +636,8 @@ def do_aa_descriptor(conf, cert=None, enc_cert=None):
|
||||
_do_nameid_format(aad, conf, "aa")
|
||||
|
||||
if cert or enc_cert:
|
||||
aad.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
|
||||
aad.key_descriptor = do_key_descriptor(cert, enc_cert,
|
||||
use=conf.metadata_key_usage)
|
||||
|
||||
attributes = conf.getattr("attribute", "aa")
|
||||
if attributes:
|
||||
@@ -664,7 +666,8 @@ def do_aq_descriptor(conf, cert=None, enc_cert=None):
|
||||
_do_nameid_format(aqs, conf, "aq")
|
||||
|
||||
if cert or enc_cert:
|
||||
aqs.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
|
||||
aqs.key_descriptor = do_key_descriptor(cert, enc_cert,
|
||||
use=conf.metadata_key_usage)
|
||||
|
||||
return aqs
|
||||
|
||||
@@ -685,7 +688,8 @@ def do_pdp_descriptor(conf, cert=None, enc_cert=None):
|
||||
_do_nameid_format(pdp, conf, "pdp")
|
||||
|
||||
if cert:
|
||||
pdp.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
|
||||
pdp.key_descriptor = do_key_descriptor(cert, enc_cert,
|
||||
use=conf.metadata_key_usage)
|
||||
|
||||
return pdp
|
||||
|
||||
@@ -702,7 +706,8 @@ def entity_descriptor(confd):
|
||||
if confd.encryption_keypairs is not None:
|
||||
enc_cert = []
|
||||
for _encryption in confd.encryption_keypairs:
|
||||
enc_cert.append("".join(open(_encryption["cert_file"]).readlines()[1:-1]))
|
||||
enc_cert.append(
|
||||
"".join(open(_encryption["cert_file"]).readlines()[1:-1]))
|
||||
|
||||
entd = md.EntityDescriptor()
|
||||
entd.entity_id = confd.entityid
|
||||
@@ -736,13 +741,15 @@ def entity_descriptor(confd):
|
||||
entd.idpsso_descriptor = do_idpsso_descriptor(confd, mycert, enc_cert)
|
||||
if "aa" in serves:
|
||||
confd.context = "aa"
|
||||
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert, enc_cert)
|
||||
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert,
|
||||
enc_cert)
|
||||
if "pdp" in serves:
|
||||
confd.context = "pdp"
|
||||
entd.pdp_descriptor = do_pdp_descriptor(confd, mycert, enc_cert)
|
||||
if "aq" in serves:
|
||||
confd.context = "aq"
|
||||
entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert, enc_cert)
|
||||
entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert,
|
||||
enc_cert)
|
||||
|
||||
return entd
|
||||
|
||||
|
@@ -1036,9 +1036,11 @@ class AuthnResponse(StatusResponse):
|
||||
"issuer": self.issuer(), "not_on_or_after": nooa,
|
||||
"authz_decision_info": self.authz_decision_info()}
|
||||
else:
|
||||
authn_statement = self.assertion.authn_statement[0]
|
||||
return {"ava": self.ava, "name_id": self.name_id,
|
||||
"came_from": self.came_from, "issuer": self.issuer(),
|
||||
"not_on_or_after": nooa, "authn_info": self.authn_info()}
|
||||
"not_on_or_after": nooa, "authn_info": self.authn_info(),
|
||||
"session_index": authn_statement.session_index}
|
||||
|
||||
def __str__(self):
|
||||
if not isinstance(self.xmlstr, six.string_types):
|
||||
|
@@ -353,7 +353,10 @@ def make_temp(string, suffix="", decode=True, delete=True):
|
||||
xmlsec function).
|
||||
"""
|
||||
ntf = NamedTemporaryFile(suffix=suffix, delete=delete)
|
||||
assert isinstance(string, six.binary_type)
|
||||
# Python3 tempfile requires byte-like object
|
||||
if not isinstance(string, six.binary_type):
|
||||
string = string.encode("utf8")
|
||||
|
||||
if decode:
|
||||
ntf.write(base64.b64decode(string))
|
||||
else:
|
||||
@@ -657,6 +660,12 @@ LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||
LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||
|
||||
|
||||
def make_str(txt):
|
||||
if isinstance(txt, six.string_types):
|
||||
return txt
|
||||
else:
|
||||
return txt.decode("utf8")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -674,29 +683,32 @@ def read_cert_from_file(cert_file, cert_type):
|
||||
return ""
|
||||
|
||||
if cert_type == "pem":
|
||||
line = open(cert_file).read().replace("\r\n", "\n").split("\n")
|
||||
_a = read_file(cert_file, 'rb').decode("utf8")
|
||||
_b = _a.replace("\r\n", "\n")
|
||||
lines = _b.split("\n")
|
||||
|
||||
if line[0] == "-----BEGIN CERTIFICATE-----":
|
||||
line = line[1:]
|
||||
elif line[0] == "-----BEGIN PUBLIC KEY-----":
|
||||
line = line[1:]
|
||||
for pattern in ("-----BEGIN CERTIFICATE-----",
|
||||
"-----BEGIN PUBLIC KEY-----"):
|
||||
if pattern in lines:
|
||||
lines = lines[lines.index(pattern)+1:]
|
||||
break
|
||||
else:
|
||||
raise CertificateError("Strange beginning of PEM file")
|
||||
|
||||
while line[-1] == "":
|
||||
line = line[:-1]
|
||||
|
||||
if line[-1] == "-----END CERTIFICATE-----":
|
||||
line = line[:-1]
|
||||
elif line[-1] == "-----END PUBLIC KEY-----":
|
||||
line = line[:-1]
|
||||
for pattern in ("-----END CERTIFICATE-----",
|
||||
"-----END PUBLIC KEY-----"):
|
||||
if pattern in lines:
|
||||
lines = lines[:lines.index(pattern)]
|
||||
break
|
||||
else:
|
||||
raise CertificateError("Strange end of PEM file")
|
||||
return "".join(line)
|
||||
return make_str("".join(lines).encode("utf8"))
|
||||
|
||||
|
||||
if cert_type in ["der", "cer", "crt"]:
|
||||
data = read_file(cert_file)
|
||||
return base64.b64encode(str(data))
|
||||
data = read_file(cert_file, 'rb')
|
||||
_cert = base64.b64encode(data)
|
||||
return make_str(_cert)
|
||||
|
||||
|
||||
class CryptoBackend():
|
||||
@@ -850,8 +862,8 @@ class CryptoBackendXmlSec1(CryptoBackend):
|
||||
'id','Id' or 'ID'
|
||||
:return: The signed statement
|
||||
"""
|
||||
if not isinstance(statement, six.binary_type):
|
||||
statement = str(statement).encode('utf-8')
|
||||
if isinstance(statement, SamlBase):
|
||||
statement = str(statement)
|
||||
|
||||
_, fil = make_temp(statement, suffix=".xml",
|
||||
decode=False, delete=self._xmlsec_delete_tmpfiles)
|
||||
@@ -1284,8 +1296,6 @@ class SecurityContext(object):
|
||||
self.encryption_keypairs = encryption_keypairs
|
||||
self.enc_cert_type = enc_cert_type
|
||||
|
||||
|
||||
|
||||
self.my_cert = read_cert_from_file(cert_file, cert_type)
|
||||
|
||||
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file,
|
||||
|
@@ -1,17 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import base64
|
||||
from saml2.sigver import pre_encryption_part, make_temp, XmlsecError, \
|
||||
SigverError
|
||||
from saml2.mdstore import MetadataStore
|
||||
from saml2.saml import assertion_from_string, EncryptedAssertion
|
||||
from saml2.samlp import response_from_string
|
||||
|
||||
from saml2 import sigver, extension_elements_to_elements
|
||||
from saml2 import sigver
|
||||
from saml2 import extension_elements_to_elements
|
||||
from saml2 import class_name
|
||||
from saml2 import time_util
|
||||
from saml2 import saml, samlp
|
||||
from saml2 import config
|
||||
from saml2.sigver import pre_encryption_part
|
||||
from saml2.sigver import make_temp
|
||||
from saml2.sigver import XmlsecError
|
||||
from saml2.sigver import SigverError
|
||||
from saml2.mdstore import MetadataStore
|
||||
from saml2.saml import assertion_from_string
|
||||
from saml2.saml import EncryptedAssertion
|
||||
from saml2.samlp import response_from_string
|
||||
from saml2.s_utils import factory, do_attribute_statement
|
||||
|
||||
from py.test import raises
|
||||
@@ -510,6 +513,6 @@ def test_xmlsec_err():
|
||||
if __name__ == "__main__":
|
||||
t = TestSecurity()
|
||||
t.setup_class()
|
||||
t.test_verify_1()
|
||||
t.test_sign_assertion()
|
||||
|
||||
#test_xmlsec_err()
|
||||
|
@@ -25,6 +25,7 @@ from saml2.response import LogoutResponse
|
||||
from saml2.saml import NAMEID_FORMAT_PERSISTENT, EncryptedAssertion, Advice
|
||||
from saml2.saml import NAMEID_FORMAT_TRANSIENT
|
||||
from saml2.saml import NameID
|
||||
from saml2.samlp import SessionIndex
|
||||
from saml2.server import Server
|
||||
from saml2.sigver import pre_encryption_part, make_temp, pre_encrypt_assertion
|
||||
from saml2.sigver import rm_xmltag
|
||||
@@ -319,6 +320,19 @@ class TestClient:
|
||||
except Exception: # missing certificate
|
||||
self.client.sec.verify_signature(ar_str, node_name=class_name(ar))
|
||||
|
||||
def test_create_logout_request(self):
|
||||
req_id, req = self.client.create_logout_request(
|
||||
"http://localhost:8088/slo", "urn:mace:example.com:saml:roland:idp",
|
||||
name_id=nid, reason="Tired", expire=in_a_while(minutes=15),
|
||||
session_indexes=["_foo"])
|
||||
|
||||
assert req.destination == "http://localhost:8088/slo"
|
||||
assert req.reason == "Tired"
|
||||
assert req.version == "2.0"
|
||||
assert req.name_id == nid
|
||||
assert req.issuer.text == "urn:mace:example.com:saml:roland:sp"
|
||||
assert req.session_index == [SessionIndex("_foo")]
|
||||
|
||||
def test_response_1(self):
|
||||
IDP = "urn:mace:example.com:saml:roland:idp"
|
||||
|
||||
@@ -359,6 +373,7 @@ class TestClient:
|
||||
assert session_info["came_from"] == "http://foo.example.com/service"
|
||||
response = samlp.response_from_string(authn_response.xmlstr)
|
||||
assert response.destination == "http://lingon.catalogix.se:8087/"
|
||||
assert "session_index" in session_info
|
||||
|
||||
# One person in the cache
|
||||
assert len(self.client.users.subjects()) == 1
|
||||
@@ -1220,6 +1235,36 @@ class TestClient:
|
||||
BINDING_HTTP_REDIRECT)
|
||||
print(res)
|
||||
|
||||
def test_do_logout_post(self):
|
||||
# information about the user from an IdP
|
||||
session_info = {
|
||||
"name_id": nid,
|
||||
"issuer": "urn:mace:example.com:saml:roland:idp",
|
||||
"not_on_or_after": in_a_while(minutes=15),
|
||||
"ava": {
|
||||
"givenName": "Anders",
|
||||
"surName": "Andersson",
|
||||
"mail": "anders.andersson@example.com"
|
||||
},
|
||||
"session_index": SessionIndex("_foo")
|
||||
}
|
||||
self.client.users.add_information_about_person(session_info)
|
||||
entity_ids = self.client.users.issuers_of_info(nid)
|
||||
assert entity_ids == ["urn:mace:example.com:saml:roland:idp"]
|
||||
resp = self.client.do_logout(nid, entity_ids, "Tired",
|
||||
in_a_while(minutes=5), sign=True,
|
||||
expected_binding=BINDING_HTTP_POST)
|
||||
assert resp
|
||||
assert len(resp) == 1
|
||||
assert list(resp.keys()) == entity_ids
|
||||
binding, info = resp[entity_ids[0]]
|
||||
assert binding == BINDING_HTTP_POST
|
||||
|
||||
_dic = unpack_form(info["data"][3])
|
||||
res = self.server.parse_logout_request(_dic["SAMLRequest"],
|
||||
BINDING_HTTP_POST)
|
||||
assert b'<ns0:SessionIndex>_foo</ns0:SessionIndex>' in res.xmlstr
|
||||
|
||||
|
||||
# Below can only be done with dummy Server
|
||||
IDP = "urn:mace:example.com:saml:roland:idp"
|
||||
|
Reference in New Issue
Block a user