Made the test work.
This commit is contained in:
@@ -245,6 +245,11 @@ class Saml2Client(Base):
|
|||||||
response = self.send_using_soap(query, destination)
|
response = self.send_using_soap(query, destination)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
|
if not response_args:
|
||||||
|
response_args = {"binding": BINDING_SOAP}
|
||||||
|
else:
|
||||||
|
response_args["binding"] = BINDING_SOAP
|
||||||
|
|
||||||
logger.info("Verifying response")
|
logger.info("Verifying response")
|
||||||
if response_args:
|
if response_args:
|
||||||
response = _response_func(response, **response_args)
|
response = _response_func(response, **response_args)
|
||||||
|
@@ -364,24 +364,21 @@ class Base(Entity):
|
|||||||
extensions=extensions,
|
extensions=extensions,
|
||||||
sign=sign)
|
sign=sign)
|
||||||
|
|
||||||
def create_assertion_id_request(self, assertion_id_refs, destination=None,
|
def create_assertion_id_request(self, assertion_id_refs, **kwargs):
|
||||||
id=0, consent=None, extensions=None,
|
|
||||||
sign=False):
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param assertion_id_refs:
|
:param assertion_id_refs:
|
||||||
:param destination: The IdP endpoint to send the request to
|
:return: One ID ref
|
||||||
:param id: Message identifier
|
|
||||||
:param consent: If the principal gave her consent to this request
|
|
||||||
:param extensions: Possible request extensions
|
|
||||||
:param sign: Whether the request should be signed or not.
|
|
||||||
:return: AssertionIDRequest instance
|
|
||||||
"""
|
"""
|
||||||
id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs]
|
# id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs]
|
||||||
|
#
|
||||||
return self._message(AssertionIDRequest, destination, id, consent,
|
# return self._message(AssertionIDRequest, destination, id, consent,
|
||||||
extensions, sign, assertion_id_ref=id_refs )
|
# extensions, sign, assertion_id_ref=id_refs )
|
||||||
|
|
||||||
|
if isinstance(assertion_id_refs, basestring):
|
||||||
|
return assertion_id_refs
|
||||||
|
else:
|
||||||
|
return assertion_id_refs[0]
|
||||||
|
|
||||||
def create_authn_query(self, subject, destination=None,
|
def create_authn_query(self, subject, destination=None,
|
||||||
authn_context=None, session_index="",
|
authn_context=None, session_index="",
|
||||||
@@ -520,9 +517,12 @@ class Base(Entity):
|
|||||||
# ------------------------------------------------------------------------
|
# ------------------------------------------------------------------------
|
||||||
|
|
||||||
def parse_attribute_query_response(self, response, binding):
|
def parse_attribute_query_response(self, response, binding):
|
||||||
|
kwargs = {"entity_id": self.config.entityid,
|
||||||
|
"attribute_converters": self.config.attribute_converters}
|
||||||
|
|
||||||
return self._parse_response(response, AttributeResponse,
|
return self._parse_response(response, AttributeResponse,
|
||||||
"attribute_consuming_service", binding)
|
"attribute_consuming_service", binding,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def parse_name_id_mapping_request_response(self, txt, binding=BINDING_SOAP):
|
def parse_name_id_mapping_request_response(self, txt, binding=BINDING_SOAP):
|
||||||
"""
|
"""
|
||||||
@@ -531,4 +531,5 @@ class Base(Entity):
|
|||||||
:param binding: Just a placeholder, it's always BINDING_SOAP
|
:param binding: Just a placeholder, it's always BINDING_SOAP
|
||||||
:return: parsed and verified <NameIDMappingResponse> instance
|
:return: parsed and verified <NameIDMappingResponse> instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self._parse_response(txt, NameIDMappingResponse, "", binding)
|
return self._parse_response(txt, NameIDMappingResponse, "", binding)
|
||||||
|
@@ -45,6 +45,9 @@ PAIRS = {
|
|||||||
class ConnectionError(Exception):
|
class ConnectionError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class HTTPError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
def _since_epoch(cdate):
|
def _since_epoch(cdate):
|
||||||
"""
|
"""
|
||||||
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
|
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
|
||||||
@@ -217,9 +220,7 @@ class HTTPBase(object):
|
|||||||
# msg should be an identifier
|
# msg should be an identifier
|
||||||
info = {
|
info = {
|
||||||
"data": "",
|
"data": "",
|
||||||
"headers": [
|
"url": "%s?ID=%s" % (destination, message)
|
||||||
("Location", "%s?ID=%s" % (destination, message))
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
@@ -273,8 +274,10 @@ class HTTPBase(object):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
xmlstr = response.text
|
if response.status_code == 200:
|
||||||
logger.info("SOAP response: %s" % xmlstr)
|
logger.info("SOAP response: %s" % response.text)
|
||||||
return response
|
return response.text
|
||||||
|
else:
|
||||||
|
raise HTTPError("%d:%s" % (response.status_code, response.error))
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
@@ -16,7 +16,8 @@ __author__ = 'rolandh'
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id"]
|
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id",
|
||||||
|
"text"]
|
||||||
|
|
||||||
class Unknown(Exception):
|
class Unknown(Exception):
|
||||||
pass
|
pass
|
||||||
@@ -168,9 +169,13 @@ class IdentDB(object):
|
|||||||
return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT,
|
return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT,
|
||||||
sp_name_qualifier, name_qualifier)
|
sp_name_qualifier, name_qualifier)
|
||||||
|
|
||||||
def permanent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
|
def persistent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
|
||||||
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
|
nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier)
|
||||||
sp_name_qualifier, name_qualifier)
|
if nameid:
|
||||||
|
return nameid
|
||||||
|
else:
|
||||||
|
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
|
||||||
|
sp_name_qualifier, name_qualifier)
|
||||||
|
|
||||||
def find_local_id(self, name_id):
|
def find_local_id(self, name_id):
|
||||||
"""
|
"""
|
||||||
@@ -191,8 +196,18 @@ class IdentDB(object):
|
|||||||
nid = decode(val)
|
nid = decode(val)
|
||||||
if nid.format == NAMEID_FORMAT_TRANSIENT:
|
if nid.format == NAMEID_FORMAT_TRANSIENT:
|
||||||
continue
|
continue
|
||||||
if getattr(nid, "sp_name_qualifier", "") == sp_name_qualifier:
|
snq = getattr(nid, "sp_name_qualifier", "")
|
||||||
if getattr(nid, "name_qualifier", "") == name_qualifier:
|
if snq and snq == sp_name_qualifier:
|
||||||
|
nq = getattr(nid, "name_qualifier", None)
|
||||||
|
if nq and nq == name_qualifier:
|
||||||
|
return nid
|
||||||
|
elif not nq and not name_qualifier:
|
||||||
|
return nid
|
||||||
|
elif not snq and not sp_name_qualifier:
|
||||||
|
nq = getattr(nid, "name_qualifier", None)
|
||||||
|
if nq and nq == name_qualifier:
|
||||||
|
return nid
|
||||||
|
elif not nq and not name_qualifier:
|
||||||
return nid
|
return nid
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
@@ -280,3 +295,6 @@ class IdentDB(object):
|
|||||||
return self.db["%s:%s" % (userid, entity_id)]
|
return self.db["%s:%s" % (userid, entity_id)]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.db.close()
|
||||||
|
@@ -106,7 +106,7 @@ class Server(Entity):
|
|||||||
def close_shelve_db(self):
|
def close_shelve_db(self):
|
||||||
"""Close the shelve db to prevent file system locking issues"""
|
"""Close the shelve db to prevent file system locking issues"""
|
||||||
if self.ident:
|
if self.ident:
|
||||||
self.ident.map.close()
|
self.ident.db.close()
|
||||||
|
|
||||||
def wants(self, sp_entity_id, index=None):
|
def wants(self, sp_entity_id, index=None):
|
||||||
""" Returns what attributes the SP requires and which are optional
|
""" Returns what attributes the SP requires and which are optional
|
||||||
|
@@ -56,6 +56,11 @@ def parse_soap_enveloped_saml_attribute_query(text):
|
|||||||
expected_tag = '{%s}AttributeQuery' % SAMLP_NAMESPACE
|
expected_tag = '{%s}AttributeQuery' % SAMLP_NAMESPACE
|
||||||
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
|
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
|
||||||
|
|
||||||
|
def parse_soap_enveloped_saml_attribute_response(text):
|
||||||
|
tags = ['{%s}Response' % SAMLP_NAMESPACE,
|
||||||
|
'{%s}AttributeResponse' % SAMLP_NAMESPACE]
|
||||||
|
return parse_soap_enveloped_saml_thingy(text, tags)
|
||||||
|
|
||||||
def parse_soap_enveloped_saml_logout_request(text):
|
def parse_soap_enveloped_saml_logout_request(text):
|
||||||
expected_tag = '{%s}LogoutRequest' % SAMLP_NAMESPACE
|
expected_tag = '{%s}LogoutRequest' % SAMLP_NAMESPACE
|
||||||
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
|
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
|
||||||
|
@@ -34,6 +34,12 @@ def unpack_form(_str, ver="SAMLRequest"):
|
|||||||
|
|
||||||
return {ver:sr, "RelayState":rs}
|
return {ver:sr, "RelayState":rs}
|
||||||
|
|
||||||
|
class DummyResponse(object):
|
||||||
|
def __init__(self, code, data, headers=None):
|
||||||
|
self.status_code = code
|
||||||
|
self.text = data
|
||||||
|
self.headers = headers or []
|
||||||
|
|
||||||
class FakeIDP(Server):
|
class FakeIDP(Server):
|
||||||
def __init__(self, config_file=""):
|
def __init__(self, config_file=""):
|
||||||
Server.__init__(self, config_file)
|
Server.__init__(self, config_file)
|
||||||
@@ -52,12 +58,19 @@ class FakeIDP(Server):
|
|||||||
if method == "GET":
|
if method == "GET":
|
||||||
path, query = url.split("?")
|
path, query = url.split("?")
|
||||||
qs_dict = parse_qs(kwargs["data"])
|
qs_dict = parse_qs(kwargs["data"])
|
||||||
|
req = qs_dict["SAMLRequest"][0]
|
||||||
|
rstate = qs_dict["RelayState"][0]
|
||||||
else:
|
else:
|
||||||
|
# Could be either POST or SOAP
|
||||||
path = url
|
path = url
|
||||||
qs_dict = parse_qs(kwargs["data"])
|
try:
|
||||||
|
qs_dict = parse_qs(kwargs["data"])
|
||||||
|
req = qs_dict["SAMLRequest"][0]
|
||||||
|
rstate = qs_dict["RelayState"][0]
|
||||||
|
except KeyError:
|
||||||
|
req = kwargs["data"]
|
||||||
|
rstate = ""
|
||||||
|
|
||||||
req = qs_dict["SAMLRequest"][0]
|
|
||||||
rstate = qs_dict["RelayState"][0]
|
|
||||||
response = ""
|
response = ""
|
||||||
|
|
||||||
# Get service from path
|
# Get service from path
|
||||||
@@ -105,9 +118,10 @@ class FakeIDP(Server):
|
|||||||
|
|
||||||
response = "%s" % authn_resp
|
response = "%s" % authn_resp
|
||||||
|
|
||||||
return pack.factory(_binding, response,
|
_dict = pack.factory(_binding, response,
|
||||||
resp_args["destination"], relay_state,
|
resp_args["destination"], relay_state,
|
||||||
"SAMLResponse")
|
"SAMLResponse")
|
||||||
|
return DummyResponse(200, **_dict)
|
||||||
|
|
||||||
def attribute_query_endpoint(self, xml_str, binding):
|
def attribute_query_endpoint(self, xml_str, binding):
|
||||||
if binding == BINDING_SOAP:
|
if binding == BINDING_SOAP:
|
||||||
@@ -139,7 +153,7 @@ class FakeIDP(Server):
|
|||||||
else: # Just POST
|
else: # Just POST
|
||||||
response = "%s" % attr_resp
|
response = "%s" % attr_resp
|
||||||
|
|
||||||
return response
|
return DummyResponse(200, response)
|
||||||
|
|
||||||
def logout_endpoint(self, xml_str, binding):
|
def logout_endpoint(self, xml_str, binding):
|
||||||
if binding == BINDING_SOAP:
|
if binding == BINDING_SOAP:
|
||||||
@@ -164,4 +178,4 @@ class FakeIDP(Server):
|
|||||||
else: # Just POST
|
else: # Just POST
|
||||||
response = "%s" % _resp
|
response = "%s" % _resp
|
||||||
|
|
||||||
return response
|
return DummyResponse(200, response)
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
import os
|
||||||
|
|
||||||
from saml2 import samlp
|
from saml2 import samlp
|
||||||
from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT
|
from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT
|
||||||
@@ -143,3 +144,32 @@ class TestIdentifier():
|
|||||||
assert nameid.format == NAMEID_FORMAT_PERSISTENT
|
assert nameid.format == NAMEID_FORMAT_PERSISTENT
|
||||||
assert nameid.text != "foobar01"
|
assert nameid.text != "foobar01"
|
||||||
|
|
||||||
|
|
||||||
|
def test_persistent_nameid(self):
|
||||||
|
sp_id = "urn:mace:umu.se:sp"
|
||||||
|
nameid = self.id.persistent_nameid("abcd0001", sp_id)
|
||||||
|
remote_id = nameid.text.strip()
|
||||||
|
print remote_id
|
||||||
|
local = self.id.find_local_id(nameid)
|
||||||
|
assert local == "abcd0001"
|
||||||
|
|
||||||
|
# Always get the same
|
||||||
|
nameid2 = self.id.persistent_nameid("abcd0001", sp_id)
|
||||||
|
assert nameid.text.strip() == nameid2.text.strip()
|
||||||
|
|
||||||
|
def test_transient_nameid(self):
|
||||||
|
sp_id = "urn:mace:umu.se:sp"
|
||||||
|
nameid = self.id.transient_nameid("abcd0001", sp_id)
|
||||||
|
remote_id = nameid.text.strip()
|
||||||
|
print remote_id
|
||||||
|
local = self.id.find_local_id(nameid)
|
||||||
|
assert local == "abcd0001"
|
||||||
|
|
||||||
|
# Getting a new, means really getting a new !
|
||||||
|
nameid2 = self.id.transient_nameid(sp_id, "abcd0001")
|
||||||
|
assert nameid.text.strip() != nameid2.text.strip()
|
||||||
|
|
||||||
|
def teardown_class(self):
|
||||||
|
if os.path.exists("foobar.db"):
|
||||||
|
os.unlink("foobar.db")
|
||||||
|
|
||||||
|
@@ -5,7 +5,8 @@ from urlparse import parse_qs
|
|||||||
from saml2.saml import AUTHN_PASSWORD
|
from saml2.saml import AUTHN_PASSWORD
|
||||||
from saml2.samlp import response_from_string
|
from saml2.samlp import response_from_string
|
||||||
|
|
||||||
from saml2.server import Server, Identifier
|
from saml2.server import Server
|
||||||
|
from saml2.ident import IdentDB
|
||||||
from saml2 import samlp, saml, client, config
|
from saml2 import samlp, saml, client, config
|
||||||
from saml2 import s_utils
|
from saml2 import s_utils
|
||||||
from saml2 import sigver
|
from saml2 import sigver
|
||||||
@@ -21,45 +22,6 @@ import os
|
|||||||
def _eq(l1,l2):
|
def _eq(l1,l2):
|
||||||
return set(l1) == set(l2)
|
return set(l1) == set(l2)
|
||||||
|
|
||||||
class TestIdentifier():
|
|
||||||
def setup_class(self):
|
|
||||||
self.ident = Identifier("foobar.db")
|
|
||||||
|
|
||||||
def test_persistent_nameid(self):
|
|
||||||
sp_id = "urn:mace:umu.se:sp"
|
|
||||||
nameid = self.ident.persistent_nameid(sp_id, "abcd0001")
|
|
||||||
remote_id = nameid.text.strip()
|
|
||||||
print remote_id
|
|
||||||
print self.ident.map
|
|
||||||
local = self.ident.local_name(sp_id, remote_id)
|
|
||||||
assert local == "abcd0001"
|
|
||||||
assert self.ident.local_name(sp_id, "pseudo random string") is None
|
|
||||||
assert self.ident.local_name(sp_id+":x", remote_id) is None
|
|
||||||
|
|
||||||
# Always get the same
|
|
||||||
nameid2 = self.ident.persistent_nameid(sp_id, "abcd0001")
|
|
||||||
assert nameid.text.strip() == nameid2.text.strip()
|
|
||||||
|
|
||||||
def test_transient_nameid(self):
|
|
||||||
sp_id = "urn:mace:umu.se:sp"
|
|
||||||
nameid = self.ident.transient_nameid(sp_id, "abcd0001")
|
|
||||||
remote_id = nameid.text.strip()
|
|
||||||
print remote_id
|
|
||||||
print self.ident.map
|
|
||||||
local = self.ident.local_name(sp_id, remote_id)
|
|
||||||
assert local == "abcd0001"
|
|
||||||
assert self.ident.local_name(sp_id, "pseudo random string") is None
|
|
||||||
assert self.ident.local_name(sp_id+":x", remote_id) is None
|
|
||||||
|
|
||||||
# Getting a new, means really getting a new !
|
|
||||||
nameid2 = self.ident.transient_nameid(sp_id, "abcd0001")
|
|
||||||
assert nameid.text.strip() != nameid2.text.strip()
|
|
||||||
|
|
||||||
def teardown_class(self):
|
|
||||||
if os.path.exists("foobar.db"):
|
|
||||||
os.unlink("foobar.db")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestServer1():
|
class TestServer1():
|
||||||
def setup_class(self):
|
def setup_class(self):
|
||||||
|
@@ -406,8 +406,8 @@ class TestClientWithDummy():
|
|||||||
http_args["headers"] = [('Content-type','application/x-www-form-urlencoded')]
|
http_args["headers"] = [('Content-type','application/x-www-form-urlencoded')]
|
||||||
|
|
||||||
response = self.client.send(**http_args)
|
response = self.client.send(**http_args)
|
||||||
|
print response.text
|
||||||
_dic = unpack_form(response["data"][3], "SAMLResponse")
|
_dic = unpack_form(response.text[3], "SAMLResponse")
|
||||||
resp = self.client.parse_authn_request_response(_dic["SAMLResponse"],
|
resp = self.client.parse_authn_request_response(_dic["SAMLResponse"],
|
||||||
BINDING_HTTP_POST,
|
BINDING_HTTP_POST,
|
||||||
{id: "/"})
|
{id: "/"})
|
||||||
|
@@ -28,9 +28,7 @@ def get_msg(hinfo, binding, response=False):
|
|||||||
msg = hinfo["data"]
|
msg = hinfo["data"]
|
||||||
else:
|
else:
|
||||||
msg = ""
|
msg = ""
|
||||||
for header, val in hinfo["headers"]:
|
return parse_qs(hinfo["url"].split("?")[1])["ID"][0]
|
||||||
if header == "Location":
|
|
||||||
return parse_qs(val.split("?")[1])["ID"][0]
|
|
||||||
else: # BINDING_HTTP_REDIRECT
|
else: # BINDING_HTTP_REDIRECT
|
||||||
parts = urlparse(hinfo["headers"][0][1])
|
parts = urlparse(hinfo["headers"][0][1])
|
||||||
msg = parse_qs(parts.query)["SAMLRequest"][0]
|
msg = parse_qs(parts.query)["SAMLRequest"][0]
|
||||||
|
Reference in New Issue
Block a user