Made the test work.
This commit is contained in:
parent
3624c769a4
commit
21880631d9
@ -245,6 +245,11 @@ class Saml2Client(Base):
|
||||
response = self.send_using_soap(query, destination)
|
||||
|
||||
if response:
|
||||
if not response_args:
|
||||
response_args = {"binding": BINDING_SOAP}
|
||||
else:
|
||||
response_args["binding"] = BINDING_SOAP
|
||||
|
||||
logger.info("Verifying response")
|
||||
if response_args:
|
||||
response = _response_func(response, **response_args)
|
||||
|
@ -364,24 +364,21 @@ class Base(Entity):
|
||||
extensions=extensions,
|
||||
sign=sign)
|
||||
|
||||
def create_assertion_id_request(self, assertion_id_refs, destination=None,
|
||||
id=0, consent=None, extensions=None,
|
||||
sign=False):
|
||||
def create_assertion_id_request(self, assertion_id_refs, **kwargs):
|
||||
"""
|
||||
|
||||
:param assertion_id_refs:
|
||||
:param destination: The IdP endpoint to send the request to
|
||||
:param id: Message identifier
|
||||
:param consent: If the principal gave her consent to this request
|
||||
:param extensions: Possible request extensions
|
||||
:param sign: Whether the request should be signed or not.
|
||||
:return: AssertionIDRequest instance
|
||||
:return: One ID ref
|
||||
"""
|
||||
id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs]
|
||||
|
||||
return self._message(AssertionIDRequest, destination, id, consent,
|
||||
extensions, sign, assertion_id_ref=id_refs )
|
||||
# id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs]
|
||||
#
|
||||
# return self._message(AssertionIDRequest, destination, id, consent,
|
||||
# extensions, sign, assertion_id_ref=id_refs )
|
||||
|
||||
if isinstance(assertion_id_refs, basestring):
|
||||
return assertion_id_refs
|
||||
else:
|
||||
return assertion_id_refs[0]
|
||||
|
||||
def create_authn_query(self, subject, destination=None,
|
||||
authn_context=None, session_index="",
|
||||
@ -516,13 +513,16 @@ class Base(Entity):
|
||||
res = self._parse_response(response, AssertionIDResponse, "", binding,
|
||||
**kwargs)
|
||||
return res
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def parse_attribute_query_response(self, response, binding):
|
||||
kwargs = {"entity_id": self.config.entityid,
|
||||
"attribute_converters": self.config.attribute_converters}
|
||||
|
||||
return self._parse_response(response, AttributeResponse,
|
||||
"attribute_consuming_service", binding)
|
||||
"attribute_consuming_service", binding,
|
||||
**kwargs)
|
||||
|
||||
def parse_name_id_mapping_request_response(self, txt, binding=BINDING_SOAP):
|
||||
"""
|
||||
@ -531,4 +531,5 @@ class Base(Entity):
|
||||
:param binding: Just a placeholder, it's always BINDING_SOAP
|
||||
:return: parsed and verified <NameIDMappingResponse> instance
|
||||
"""
|
||||
|
||||
return self._parse_response(txt, NameIDMappingResponse, "", binding)
|
||||
|
@ -45,6 +45,9 @@ PAIRS = {
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
class HTTPError(Exception):
|
||||
pass
|
||||
|
||||
def _since_epoch(cdate):
|
||||
"""
|
||||
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
|
||||
@ -217,9 +220,7 @@ class HTTPBase(object):
|
||||
# msg should be an identifier
|
||||
info = {
|
||||
"data": "",
|
||||
"headers": [
|
||||
("Location", "%s?ID=%s" % (destination, message))
|
||||
]
|
||||
"url": "%s?ID=%s" % (destination, message)
|
||||
}
|
||||
else:
|
||||
raise NotImplemented
|
||||
@ -273,8 +274,10 @@ class HTTPBase(object):
|
||||
raise
|
||||
|
||||
if response:
|
||||
xmlstr = response.text
|
||||
logger.info("SOAP response: %s" % xmlstr)
|
||||
return response
|
||||
if response.status_code == 200:
|
||||
logger.info("SOAP response: %s" % response.text)
|
||||
return response.text
|
||||
else:
|
||||
raise HTTPError("%d:%s" % (response.status_code, response.error))
|
||||
else:
|
||||
return None
|
||||
|
@ -16,7 +16,8 @@ __author__ = 'rolandh'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id"]
|
||||
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id",
|
||||
"text"]
|
||||
|
||||
class Unknown(Exception):
|
||||
pass
|
||||
@ -168,9 +169,13 @@ class IdentDB(object):
|
||||
return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT,
|
||||
sp_name_qualifier, name_qualifier)
|
||||
|
||||
def permanent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
|
||||
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
|
||||
sp_name_qualifier, name_qualifier)
|
||||
def persistent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
|
||||
nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier)
|
||||
if nameid:
|
||||
return nameid
|
||||
else:
|
||||
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
|
||||
sp_name_qualifier, name_qualifier)
|
||||
|
||||
def find_local_id(self, name_id):
|
||||
"""
|
||||
@ -191,8 +196,18 @@ class IdentDB(object):
|
||||
nid = decode(val)
|
||||
if nid.format == NAMEID_FORMAT_TRANSIENT:
|
||||
continue
|
||||
if getattr(nid, "sp_name_qualifier", "") == sp_name_qualifier:
|
||||
if getattr(nid, "name_qualifier", "") == name_qualifier:
|
||||
snq = getattr(nid, "sp_name_qualifier", "")
|
||||
if snq and snq == sp_name_qualifier:
|
||||
nq = getattr(nid, "name_qualifier", None)
|
||||
if nq and nq == name_qualifier:
|
||||
return nid
|
||||
elif not nq and not name_qualifier:
|
||||
return nid
|
||||
elif not snq and not sp_name_qualifier:
|
||||
nq = getattr(nid, "name_qualifier", None)
|
||||
if nq and nq == name_qualifier:
|
||||
return nid
|
||||
elif not nq and not name_qualifier:
|
||||
return nid
|
||||
except KeyError:
|
||||
pass
|
||||
@ -279,4 +294,7 @@ class IdentDB(object):
|
||||
try:
|
||||
return self.db["%s:%s" % (userid, entity_id)]
|
||||
except KeyError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def close(self):
|
||||
self.db.close()
|
||||
|
@ -106,7 +106,7 @@ class Server(Entity):
|
||||
def close_shelve_db(self):
|
||||
"""Close the shelve db to prevent file system locking issues"""
|
||||
if self.ident:
|
||||
self.ident.map.close()
|
||||
self.ident.db.close()
|
||||
|
||||
def wants(self, sp_entity_id, index=None):
|
||||
""" Returns what attributes the SP requires and which are optional
|
||||
|
@ -56,6 +56,11 @@ def parse_soap_enveloped_saml_attribute_query(text):
|
||||
expected_tag = '{%s}AttributeQuery' % SAMLP_NAMESPACE
|
||||
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
|
||||
|
||||
def parse_soap_enveloped_saml_attribute_response(text):
|
||||
tags = ['{%s}Response' % SAMLP_NAMESPACE,
|
||||
'{%s}AttributeResponse' % SAMLP_NAMESPACE]
|
||||
return parse_soap_enveloped_saml_thingy(text, tags)
|
||||
|
||||
def parse_soap_enveloped_saml_logout_request(text):
|
||||
expected_tag = '{%s}LogoutRequest' % SAMLP_NAMESPACE
|
||||
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
|
||||
|
@ -34,6 +34,12 @@ def unpack_form(_str, ver="SAMLRequest"):
|
||||
|
||||
return {ver:sr, "RelayState":rs}
|
||||
|
||||
class DummyResponse(object):
|
||||
def __init__(self, code, data, headers=None):
|
||||
self.status_code = code
|
||||
self.text = data
|
||||
self.headers = headers or []
|
||||
|
||||
class FakeIDP(Server):
|
||||
def __init__(self, config_file=""):
|
||||
Server.__init__(self, config_file)
|
||||
@ -52,12 +58,19 @@ class FakeIDP(Server):
|
||||
if method == "GET":
|
||||
path, query = url.split("?")
|
||||
qs_dict = parse_qs(kwargs["data"])
|
||||
req = qs_dict["SAMLRequest"][0]
|
||||
rstate = qs_dict["RelayState"][0]
|
||||
else:
|
||||
# Could be either POST or SOAP
|
||||
path = url
|
||||
qs_dict = parse_qs(kwargs["data"])
|
||||
try:
|
||||
qs_dict = parse_qs(kwargs["data"])
|
||||
req = qs_dict["SAMLRequest"][0]
|
||||
rstate = qs_dict["RelayState"][0]
|
||||
except KeyError:
|
||||
req = kwargs["data"]
|
||||
rstate = ""
|
||||
|
||||
req = qs_dict["SAMLRequest"][0]
|
||||
rstate = qs_dict["RelayState"][0]
|
||||
response = ""
|
||||
|
||||
# Get service from path
|
||||
@ -105,9 +118,10 @@ class FakeIDP(Server):
|
||||
|
||||
response = "%s" % authn_resp
|
||||
|
||||
return pack.factory(_binding, response,
|
||||
_dict = pack.factory(_binding, response,
|
||||
resp_args["destination"], relay_state,
|
||||
"SAMLResponse")
|
||||
return DummyResponse(200, **_dict)
|
||||
|
||||
def attribute_query_endpoint(self, xml_str, binding):
|
||||
if binding == BINDING_SOAP:
|
||||
@ -139,7 +153,7 @@ class FakeIDP(Server):
|
||||
else: # Just POST
|
||||
response = "%s" % attr_resp
|
||||
|
||||
return response
|
||||
return DummyResponse(200, response)
|
||||
|
||||
def logout_endpoint(self, xml_str, binding):
|
||||
if binding == BINDING_SOAP:
|
||||
@ -164,4 +178,4 @@ class FakeIDP(Server):
|
||||
else: # Just POST
|
||||
response = "%s" % _resp
|
||||
|
||||
return response
|
||||
return DummyResponse(200, response)
|
||||
|
@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
|
||||
from saml2 import samlp
|
||||
from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT
|
||||
@ -142,4 +143,33 @@ class TestIdentifier():
|
||||
assert nameid.sp_name_qualifier == 'http://vo.example.org/design'
|
||||
assert nameid.format == NAMEID_FORMAT_PERSISTENT
|
||||
assert nameid.text != "foobar01"
|
||||
|
||||
|
||||
|
||||
def test_persistent_nameid(self):
|
||||
sp_id = "urn:mace:umu.se:sp"
|
||||
nameid = self.id.persistent_nameid("abcd0001", sp_id)
|
||||
remote_id = nameid.text.strip()
|
||||
print remote_id
|
||||
local = self.id.find_local_id(nameid)
|
||||
assert local == "abcd0001"
|
||||
|
||||
# Always get the same
|
||||
nameid2 = self.id.persistent_nameid("abcd0001", sp_id)
|
||||
assert nameid.text.strip() == nameid2.text.strip()
|
||||
|
||||
def test_transient_nameid(self):
|
||||
sp_id = "urn:mace:umu.se:sp"
|
||||
nameid = self.id.transient_nameid("abcd0001", sp_id)
|
||||
remote_id = nameid.text.strip()
|
||||
print remote_id
|
||||
local = self.id.find_local_id(nameid)
|
||||
assert local == "abcd0001"
|
||||
|
||||
# Getting a new, means really getting a new !
|
||||
nameid2 = self.id.transient_nameid(sp_id, "abcd0001")
|
||||
assert nameid.text.strip() != nameid2.text.strip()
|
||||
|
||||
def teardown_class(self):
|
||||
if os.path.exists("foobar.db"):
|
||||
os.unlink("foobar.db")
|
||||
|
||||
|
@ -5,7 +5,8 @@ from urlparse import parse_qs
|
||||
from saml2.saml import AUTHN_PASSWORD
|
||||
from saml2.samlp import response_from_string
|
||||
|
||||
from saml2.server import Server, Identifier
|
||||
from saml2.server import Server
|
||||
from saml2.ident import IdentDB
|
||||
from saml2 import samlp, saml, client, config
|
||||
from saml2 import s_utils
|
||||
from saml2 import sigver
|
||||
@ -21,45 +22,6 @@ import os
|
||||
def _eq(l1,l2):
|
||||
return set(l1) == set(l2)
|
||||
|
||||
class TestIdentifier():
|
||||
def setup_class(self):
|
||||
self.ident = Identifier("foobar.db")
|
||||
|
||||
def test_persistent_nameid(self):
|
||||
sp_id = "urn:mace:umu.se:sp"
|
||||
nameid = self.ident.persistent_nameid(sp_id, "abcd0001")
|
||||
remote_id = nameid.text.strip()
|
||||
print remote_id
|
||||
print self.ident.map
|
||||
local = self.ident.local_name(sp_id, remote_id)
|
||||
assert local == "abcd0001"
|
||||
assert self.ident.local_name(sp_id, "pseudo random string") is None
|
||||
assert self.ident.local_name(sp_id+":x", remote_id) is None
|
||||
|
||||
# Always get the same
|
||||
nameid2 = self.ident.persistent_nameid(sp_id, "abcd0001")
|
||||
assert nameid.text.strip() == nameid2.text.strip()
|
||||
|
||||
def test_transient_nameid(self):
|
||||
sp_id = "urn:mace:umu.se:sp"
|
||||
nameid = self.ident.transient_nameid(sp_id, "abcd0001")
|
||||
remote_id = nameid.text.strip()
|
||||
print remote_id
|
||||
print self.ident.map
|
||||
local = self.ident.local_name(sp_id, remote_id)
|
||||
assert local == "abcd0001"
|
||||
assert self.ident.local_name(sp_id, "pseudo random string") is None
|
||||
assert self.ident.local_name(sp_id+":x", remote_id) is None
|
||||
|
||||
# Getting a new, means really getting a new !
|
||||
nameid2 = self.ident.transient_nameid(sp_id, "abcd0001")
|
||||
assert nameid.text.strip() != nameid2.text.strip()
|
||||
|
||||
def teardown_class(self):
|
||||
if os.path.exists("foobar.db"):
|
||||
os.unlink("foobar.db")
|
||||
|
||||
|
||||
|
||||
class TestServer1():
|
||||
def setup_class(self):
|
||||
|
@ -406,8 +406,8 @@ class TestClientWithDummy():
|
||||
http_args["headers"] = [('Content-type','application/x-www-form-urlencoded')]
|
||||
|
||||
response = self.client.send(**http_args)
|
||||
|
||||
_dic = unpack_form(response["data"][3], "SAMLResponse")
|
||||
print response.text
|
||||
_dic = unpack_form(response.text[3], "SAMLResponse")
|
||||
resp = self.client.parse_authn_request_response(_dic["SAMLResponse"],
|
||||
BINDING_HTTP_POST,
|
||||
{id: "/"})
|
||||
|
@ -28,9 +28,7 @@ def get_msg(hinfo, binding, response=False):
|
||||
msg = hinfo["data"]
|
||||
else:
|
||||
msg = ""
|
||||
for header, val in hinfo["headers"]:
|
||||
if header == "Location":
|
||||
return parse_qs(val.split("?")[1])["ID"][0]
|
||||
return parse_qs(hinfo["url"].split("?")[1])["ID"][0]
|
||||
else: # BINDING_HTTP_REDIRECT
|
||||
parts = urlparse(hinfo["headers"][0][1])
|
||||
msg = parse_qs(parts.query)["SAMLRequest"][0]
|
||||
|
Loading…
Reference in New Issue
Block a user