Fixed various errors.

This commit is contained in:
Roland Hedberg
2015-10-04 10:15:13 -04:00
parent dfad000305
commit a12cc2a979
6 changed files with 125 additions and 53 deletions

View File

@@ -206,9 +206,14 @@ class Saml2Client(Base):
destination = destinations(srvs)[0]
logger.info("destination to provider: %s" % destination)
try:
session_info = self.users.get_info_from(name_id, entity_id)
session_indexes = [session_info['session_index']]
except KeyError:
session_indexes = None
req_id, request = self.create_logout_request(
destination, entity_id, name_id=name_id, reason=reason,
expire=expire)
expire=expire, session_indexes=session_indexes)
# to_sign = []
if binding.startswith("http://"):

View File

@@ -59,16 +59,17 @@ bMDNS = b'"urn:oasis:names:tc:SAML:2.0:metadata"'
XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
bXMLNSXS = b" xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
def metadata_tostring_fix(desc, nspair, xmlstring=""):
if not xmlstring:
xmlstring = desc.to_string(nspair)
if six.PY2:
if "\"xs:string\"" in xmlstring and XMLNSXS not in xmlstring:
xmlstring = xmlstring.replace(MDNS, MDNS+XMLNSXS)
xmlstring = xmlstring.replace(MDNS, MDNS + XMLNSXS)
else:
if b"\"xs:string\"" in xmlstring and bXMLNSXS not in xmlstring:
xmlstring = xmlstring.replace(bMDNS, bMDNS+bXMLNSXS)
xmlstring = xmlstring.replace(bMDNS, bMDNS + bXMLNSXS)
return xmlstring
@@ -77,7 +78,7 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
keyfile=None, mid=None, name=None, sign=None):
valid_for = 0
nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
#paths = [".", "/opt/local/bin"]
# paths = [".", "/opt/local/bin"]
if valid:
valid_for = int(valid) # Hours
@@ -97,11 +98,8 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
secc = security_context(conf)
if mid:
desc = entities_descriptor(eds, valid_for, name, mid,
sign, secc)
valid_instance(desc)
return metadata_tostring_fix(desc, nspair)
eid, xmldoc = entities_descriptor(eds, valid_for, name, mid,
sign, secc)
else:
eid = eds[0]
if sign:
@@ -109,9 +107,8 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
else:
xmldoc = None
valid_instance(eid)
xmldoc = metadata_tostring_fix(eid, nspair, xmldoc)
return xmldoc
valid_instance(eid)
return metadata_tostring_fix(eid, nspair, xmldoc)
def _localized_name(val, klass):
@@ -346,6 +343,7 @@ def do_idpdisc(discovery_response):
return idpdisc.DiscoveryResponse(index="0", location=discovery_response,
binding=idpdisc.NAMESPACE)
ENDPOINTS = {
"sp": {
"artifact_resolution_service": (md.ArtifactResolutionService, True),
@@ -425,7 +423,8 @@ def do_endpoints(conf, endpoints):
servs = []
i = 1
for args in conf[endpoint]:
if isinstance(args, six.string_types): # Assume it's the location
if isinstance(args,
six.string_types): # Assume it's the location
args = {"location": args,
"binding": DEFAULT_BINDING[endpoint]}
elif isinstance(args, tuple) or isinstance(args, list):
@@ -453,16 +452,16 @@ def do_endpoints(conf, endpoints):
pass
return service
DEFAULT = {
"want_assertions_signed": "true",
"authn_requests_signed": "false",
"want_authn_requests_signed": "false",
#"want_authn_requests_only_with_valid_cert": "false",
# "want_authn_requests_only_with_valid_cert": "false",
}
def do_attribute_consuming_service(conf, spsso):
service_description = service_name = None
requested_attributes = []
acs = conf.attribute_converters
@@ -557,7 +556,8 @@ def do_spsso_descriptor(conf, cert=None, enc_cert=None):
if cert or enc_cert:
metadata_key_usage = conf.metadata_key_usage
spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert, use=metadata_key_usage)
spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert,
use=metadata_key_usage)
for key in ["want_assertions_signed", "authn_requests_signed"]:
try:
@@ -605,10 +605,11 @@ def do_idpsso_descriptor(conf, cert=None, enc_cert=None):
idpsso.extensions.add_extension_element(do_uiinfo(ui_info))
if cert or enc_cert:
idpsso.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
idpsso.key_descriptor = do_key_descriptor(cert, enc_cert,
use=conf.metadata_key_usage)
for key in ["want_authn_requests_signed"]:
#"want_authn_requests_only_with_valid_cert"]:
# "want_authn_requests_only_with_valid_cert"]:
try:
val = conf.getattr(key, "idp")
if val is None:
@@ -635,7 +636,8 @@ def do_aa_descriptor(conf, cert=None, enc_cert=None):
_do_nameid_format(aad, conf, "aa")
if cert or enc_cert:
aad.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
aad.key_descriptor = do_key_descriptor(cert, enc_cert,
use=conf.metadata_key_usage)
attributes = conf.getattr("attribute", "aa")
if attributes:
@@ -664,7 +666,8 @@ def do_aq_descriptor(conf, cert=None, enc_cert=None):
_do_nameid_format(aqs, conf, "aq")
if cert or enc_cert:
aqs.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
aqs.key_descriptor = do_key_descriptor(cert, enc_cert,
use=conf.metadata_key_usage)
return aqs
@@ -685,7 +688,8 @@ def do_pdp_descriptor(conf, cert=None, enc_cert=None):
_do_nameid_format(pdp, conf, "pdp")
if cert:
pdp.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
pdp.key_descriptor = do_key_descriptor(cert, enc_cert,
use=conf.metadata_key_usage)
return pdp
@@ -702,7 +706,8 @@ def entity_descriptor(confd):
if confd.encryption_keypairs is not None:
enc_cert = []
for _encryption in confd.encryption_keypairs:
enc_cert.append("".join(open(_encryption["cert_file"]).readlines()[1:-1]))
enc_cert.append(
"".join(open(_encryption["cert_file"]).readlines()[1:-1]))
entd = md.EntityDescriptor()
entd.entity_id = confd.entityid
@@ -736,13 +741,15 @@ def entity_descriptor(confd):
entd.idpsso_descriptor = do_idpsso_descriptor(confd, mycert, enc_cert)
if "aa" in serves:
confd.context = "aa"
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert, enc_cert)
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert,
enc_cert)
if "pdp" in serves:
confd.context = "pdp"
entd.pdp_descriptor = do_pdp_descriptor(confd, mycert, enc_cert)
if "aq" in serves:
confd.context = "aq"
entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert, enc_cert)
entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert,
enc_cert)
return entd

View File

@@ -1036,9 +1036,11 @@ class AuthnResponse(StatusResponse):
"issuer": self.issuer(), "not_on_or_after": nooa,
"authz_decision_info": self.authz_decision_info()}
else:
authn_statement = self.assertion.authn_statement[0]
return {"ava": self.ava, "name_id": self.name_id,
"came_from": self.came_from, "issuer": self.issuer(),
"not_on_or_after": nooa, "authn_info": self.authn_info()}
"not_on_or_after": nooa, "authn_info": self.authn_info(),
"session_index": authn_statement.session_index}
def __str__(self):
if not isinstance(self.xmlstr, six.string_types):

View File

@@ -353,7 +353,10 @@ def make_temp(string, suffix="", decode=True, delete=True):
xmlsec function).
"""
ntf = NamedTemporaryFile(suffix=suffix, delete=delete)
assert isinstance(string, six.binary_type)
# Python3 tempfile requires byte-like object
if not isinstance(string, six.binary_type):
string = string.encode("utf8")
if decode:
ntf.write(base64.b64decode(string))
else:
@@ -657,6 +660,12 @@ LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
def make_str(txt):
if isinstance(txt, six.string_types):
return txt
else:
return txt.decode("utf8")
# ---------------------------------------------------------------------------
@@ -674,29 +683,32 @@ def read_cert_from_file(cert_file, cert_type):
return ""
if cert_type == "pem":
line = open(cert_file).read().replace("\r\n", "\n").split("\n")
_a = read_file(cert_file, 'rb').decode("utf8")
_b = _a.replace("\r\n", "\n")
lines = _b.split("\n")
if line[0] == "-----BEGIN CERTIFICATE-----":
line = line[1:]
elif line[0] == "-----BEGIN PUBLIC KEY-----":
line = line[1:]
for pattern in ("-----BEGIN CERTIFICATE-----",
"-----BEGIN PUBLIC KEY-----"):
if pattern in lines:
lines = lines[lines.index(pattern)+1:]
break
else:
raise CertificateError("Strange beginning of PEM file")
while line[-1] == "":
line = line[:-1]
if line[-1] == "-----END CERTIFICATE-----":
line = line[:-1]
elif line[-1] == "-----END PUBLIC KEY-----":
line = line[:-1]
for pattern in ("-----END CERTIFICATE-----",
"-----END PUBLIC KEY-----"):
if pattern in lines:
lines = lines[:lines.index(pattern)]
break
else:
raise CertificateError("Strange end of PEM file")
return "".join(line)
return make_str("".join(lines).encode("utf8"))
if cert_type in ["der", "cer", "crt"]:
data = read_file(cert_file)
return base64.b64encode(str(data))
data = read_file(cert_file, 'rb')
_cert = base64.b64encode(data)
return make_str(_cert)
class CryptoBackend():
@@ -850,8 +862,8 @@ class CryptoBackendXmlSec1(CryptoBackend):
'id','Id' or 'ID'
:return: The signed statement
"""
if not isinstance(statement, six.binary_type):
statement = str(statement).encode('utf-8')
if isinstance(statement, SamlBase):
statement = str(statement)
_, fil = make_temp(statement, suffix=".xml",
decode=False, delete=self._xmlsec_delete_tmpfiles)
@@ -1284,8 +1296,6 @@ class SecurityContext(object):
self.encryption_keypairs = encryption_keypairs
self.enc_cert_type = enc_cert_type
self.my_cert = read_cert_from_file(cert_file, cert_type)
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file,

View File

@@ -1,17 +1,20 @@
#!/usr/bin/env python
import base64
from saml2.sigver import pre_encryption_part, make_temp, XmlsecError, \
SigverError
from saml2.mdstore import MetadataStore
from saml2.saml import assertion_from_string, EncryptedAssertion
from saml2.samlp import response_from_string
from saml2 import sigver, extension_elements_to_elements
from saml2 import sigver
from saml2 import extension_elements_to_elements
from saml2 import class_name
from saml2 import time_util
from saml2 import saml, samlp
from saml2 import config
from saml2.sigver import pre_encryption_part
from saml2.sigver import make_temp
from saml2.sigver import XmlsecError
from saml2.sigver import SigverError
from saml2.mdstore import MetadataStore
from saml2.saml import assertion_from_string
from saml2.saml import EncryptedAssertion
from saml2.samlp import response_from_string
from saml2.s_utils import factory, do_attribute_statement
from py.test import raises
@@ -510,6 +513,6 @@ def test_xmlsec_err():
if __name__ == "__main__":
t = TestSecurity()
t.setup_class()
t.test_verify_1()
t.test_sign_assertion()
#test_xmlsec_err()

View File

@@ -25,6 +25,7 @@ from saml2.response import LogoutResponse
from saml2.saml import NAMEID_FORMAT_PERSISTENT, EncryptedAssertion, Advice
from saml2.saml import NAMEID_FORMAT_TRANSIENT
from saml2.saml import NameID
from saml2.samlp import SessionIndex
from saml2.server import Server
from saml2.sigver import pre_encryption_part, make_temp, pre_encrypt_assertion
from saml2.sigver import rm_xmltag
@@ -319,6 +320,19 @@ class TestClient:
except Exception: # missing certificate
self.client.sec.verify_signature(ar_str, node_name=class_name(ar))
def test_create_logout_request(self):
req_id, req = self.client.create_logout_request(
"http://localhost:8088/slo", "urn:mace:example.com:saml:roland:idp",
name_id=nid, reason="Tired", expire=in_a_while(minutes=15),
session_indexes=["_foo"])
assert req.destination == "http://localhost:8088/slo"
assert req.reason == "Tired"
assert req.version == "2.0"
assert req.name_id == nid
assert req.issuer.text == "urn:mace:example.com:saml:roland:sp"
assert req.session_index == [SessionIndex("_foo")]
def test_response_1(self):
IDP = "urn:mace:example.com:saml:roland:idp"
@@ -359,6 +373,7 @@ class TestClient:
assert session_info["came_from"] == "http://foo.example.com/service"
response = samlp.response_from_string(authn_response.xmlstr)
assert response.destination == "http://lingon.catalogix.se:8087/"
assert "session_index" in session_info
# One person in the cache
assert len(self.client.users.subjects()) == 1
@@ -1220,6 +1235,36 @@ class TestClient:
BINDING_HTTP_REDIRECT)
print(res)
def test_do_logout_post(self):
# information about the user from an IdP
session_info = {
"name_id": nid,
"issuer": "urn:mace:example.com:saml:roland:idp",
"not_on_or_after": in_a_while(minutes=15),
"ava": {
"givenName": "Anders",
"surName": "Andersson",
"mail": "anders.andersson@example.com"
},
"session_index": SessionIndex("_foo")
}
self.client.users.add_information_about_person(session_info)
entity_ids = self.client.users.issuers_of_info(nid)
assert entity_ids == ["urn:mace:example.com:saml:roland:idp"]
resp = self.client.do_logout(nid, entity_ids, "Tired",
in_a_while(minutes=5), sign=True,
expected_binding=BINDING_HTTP_POST)
assert resp
assert len(resp) == 1
assert list(resp.keys()) == entity_ids
binding, info = resp[entity_ids[0]]
assert binding == BINDING_HTTP_POST
_dic = unpack_form(info["data"][3])
res = self.server.parse_logout_request(_dic["SAMLRequest"],
BINDING_HTTP_POST)
assert b'<ns0:SessionIndex>_foo</ns0:SessionIndex>' in res.xmlstr
# Below can only be done with dummy Server
IDP = "urn:mace:example.com:saml:roland:idp"