Made the test work.

This commit is contained in:
Roland Hedberg 2013-01-18 11:53:31 +01:00
parent 3624c769a4
commit 21880631d9
11 changed files with 117 additions and 81 deletions

View File

@ -245,6 +245,11 @@ class Saml2Client(Base):
response = self.send_using_soap(query, destination)
if response:
if not response_args:
response_args = {"binding": BINDING_SOAP}
else:
response_args["binding"] = BINDING_SOAP
logger.info("Verifying response")
if response_args:
response = _response_func(response, **response_args)

View File

@ -364,24 +364,21 @@ class Base(Entity):
extensions=extensions,
sign=sign)
def create_assertion_id_request(self, assertion_id_refs, destination=None,
id=0, consent=None, extensions=None,
sign=False):
def create_assertion_id_request(self, assertion_id_refs, **kwargs):
"""
:param assertion_id_refs:
:param destination: The IdP endpoint to send the request to
:param id: Message identifier
:param consent: If the principal gave her consent to this request
:param extensions: Possible request extensions
:param sign: Whether the request should be signed or not.
:return: AssertionIDRequest instance
:return: One ID ref
"""
id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs]
return self._message(AssertionIDRequest, destination, id, consent,
extensions, sign, assertion_id_ref=id_refs )
# id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs]
#
# return self._message(AssertionIDRequest, destination, id, consent,
# extensions, sign, assertion_id_ref=id_refs )
if isinstance(assertion_id_refs, basestring):
return assertion_id_refs
else:
return assertion_id_refs[0]
def create_authn_query(self, subject, destination=None,
authn_context=None, session_index="",
@ -516,13 +513,16 @@ class Base(Entity):
res = self._parse_response(response, AssertionIDResponse, "", binding,
**kwargs)
return res
# ------------------------------------------------------------------------
def parse_attribute_query_response(self, response, binding):
kwargs = {"entity_id": self.config.entityid,
"attribute_converters": self.config.attribute_converters}
return self._parse_response(response, AttributeResponse,
"attribute_consuming_service", binding)
"attribute_consuming_service", binding,
**kwargs)
def parse_name_id_mapping_request_response(self, txt, binding=BINDING_SOAP):
"""
@ -531,4 +531,5 @@ class Base(Entity):
:param binding: Just a placeholder, it's always BINDING_SOAP
:return: parsed and verified <NameIDMappingResponse> instance
"""
return self._parse_response(txt, NameIDMappingResponse, "", binding)

View File

@ -45,6 +45,9 @@ PAIRS = {
class ConnectionError(Exception):
pass
class HTTPError(Exception):
pass
def _since_epoch(cdate):
"""
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
@ -217,9 +220,7 @@ class HTTPBase(object):
# msg should be an identifier
info = {
"data": "",
"headers": [
("Location", "%s?ID=%s" % (destination, message))
]
"url": "%s?ID=%s" % (destination, message)
}
else:
raise NotImplemented
@ -273,8 +274,10 @@ class HTTPBase(object):
raise
if response:
xmlstr = response.text
logger.info("SOAP response: %s" % xmlstr)
return response
if response.status_code == 200:
logger.info("SOAP response: %s" % response.text)
return response.text
else:
raise HTTPError("%d:%s" % (response.status_code, response.error))
else:
return None

View File

@ -16,7 +16,8 @@ __author__ = 'rolandh'
logger = logging.getLogger(__name__)
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id"]
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id",
"text"]
class Unknown(Exception):
pass
@ -168,9 +169,13 @@ class IdentDB(object):
return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT,
sp_name_qualifier, name_qualifier)
def permanent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier, name_qualifier)
def persistent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier)
if nameid:
return nameid
else:
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier, name_qualifier)
def find_local_id(self, name_id):
"""
@ -191,8 +196,18 @@ class IdentDB(object):
nid = decode(val)
if nid.format == NAMEID_FORMAT_TRANSIENT:
continue
if getattr(nid, "sp_name_qualifier", "") == sp_name_qualifier:
if getattr(nid, "name_qualifier", "") == name_qualifier:
snq = getattr(nid, "sp_name_qualifier", "")
if snq and snq == sp_name_qualifier:
nq = getattr(nid, "name_qualifier", None)
if nq and nq == name_qualifier:
return nid
elif not nq and not name_qualifier:
return nid
elif not snq and not sp_name_qualifier:
nq = getattr(nid, "name_qualifier", None)
if nq and nq == name_qualifier:
return nid
elif not nq and not name_qualifier:
return nid
except KeyError:
pass
@ -279,4 +294,7 @@ class IdentDB(object):
try:
return self.db["%s:%s" % (userid, entity_id)]
except KeyError:
return None
return None
def close(self):
self.db.close()

View File

@ -106,7 +106,7 @@ class Server(Entity):
def close_shelve_db(self):
"""Close the shelve db to prevent file system locking issues"""
if self.ident:
self.ident.map.close()
self.ident.db.close()
def wants(self, sp_entity_id, index=None):
""" Returns what attributes the SP requires and which are optional

View File

@ -56,6 +56,11 @@ def parse_soap_enveloped_saml_attribute_query(text):
expected_tag = '{%s}AttributeQuery' % SAMLP_NAMESPACE
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
def parse_soap_enveloped_saml_attribute_response(text):
tags = ['{%s}Response' % SAMLP_NAMESPACE,
'{%s}AttributeResponse' % SAMLP_NAMESPACE]
return parse_soap_enveloped_saml_thingy(text, tags)
def parse_soap_enveloped_saml_logout_request(text):
expected_tag = '{%s}LogoutRequest' % SAMLP_NAMESPACE
return parse_soap_enveloped_saml_thingy(text, [expected_tag])

View File

@ -34,6 +34,12 @@ def unpack_form(_str, ver="SAMLRequest"):
return {ver:sr, "RelayState":rs}
class DummyResponse(object):
def __init__(self, code, data, headers=None):
self.status_code = code
self.text = data
self.headers = headers or []
class FakeIDP(Server):
def __init__(self, config_file=""):
Server.__init__(self, config_file)
@ -52,12 +58,19 @@ class FakeIDP(Server):
if method == "GET":
path, query = url.split("?")
qs_dict = parse_qs(kwargs["data"])
req = qs_dict["SAMLRequest"][0]
rstate = qs_dict["RelayState"][0]
else:
# Could be either POST or SOAP
path = url
qs_dict = parse_qs(kwargs["data"])
try:
qs_dict = parse_qs(kwargs["data"])
req = qs_dict["SAMLRequest"][0]
rstate = qs_dict["RelayState"][0]
except KeyError:
req = kwargs["data"]
rstate = ""
req = qs_dict["SAMLRequest"][0]
rstate = qs_dict["RelayState"][0]
response = ""
# Get service from path
@ -105,9 +118,10 @@ class FakeIDP(Server):
response = "%s" % authn_resp
return pack.factory(_binding, response,
_dict = pack.factory(_binding, response,
resp_args["destination"], relay_state,
"SAMLResponse")
return DummyResponse(200, **_dict)
def attribute_query_endpoint(self, xml_str, binding):
if binding == BINDING_SOAP:
@ -139,7 +153,7 @@ class FakeIDP(Server):
else: # Just POST
response = "%s" % attr_resp
return response
return DummyResponse(200, response)
def logout_endpoint(self, xml_str, binding):
if binding == BINDING_SOAP:
@ -164,4 +178,4 @@ class FakeIDP(Server):
else: # Just POST
response = "%s" % _resp
return response
return DummyResponse(200, response)

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python
import os
from saml2 import samlp
from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT
@ -142,4 +143,33 @@ class TestIdentifier():
assert nameid.sp_name_qualifier == 'http://vo.example.org/design'
assert nameid.format == NAMEID_FORMAT_PERSISTENT
assert nameid.text != "foobar01"
def test_persistent_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.id.persistent_nameid("abcd0001", sp_id)
remote_id = nameid.text.strip()
print remote_id
local = self.id.find_local_id(nameid)
assert local == "abcd0001"
# Always get the same
nameid2 = self.id.persistent_nameid("abcd0001", sp_id)
assert nameid.text.strip() == nameid2.text.strip()
def test_transient_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.id.transient_nameid("abcd0001", sp_id)
remote_id = nameid.text.strip()
print remote_id
local = self.id.find_local_id(nameid)
assert local == "abcd0001"
# Getting a new, means really getting a new !
nameid2 = self.id.transient_nameid(sp_id, "abcd0001")
assert nameid.text.strip() != nameid2.text.strip()
def teardown_class(self):
if os.path.exists("foobar.db"):
os.unlink("foobar.db")

View File

@ -5,7 +5,8 @@ from urlparse import parse_qs
from saml2.saml import AUTHN_PASSWORD
from saml2.samlp import response_from_string
from saml2.server import Server, Identifier
from saml2.server import Server
from saml2.ident import IdentDB
from saml2 import samlp, saml, client, config
from saml2 import s_utils
from saml2 import sigver
@ -21,45 +22,6 @@ import os
def _eq(l1,l2):
return set(l1) == set(l2)
class TestIdentifier():
def setup_class(self):
self.ident = Identifier("foobar.db")
def test_persistent_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.ident.persistent_nameid(sp_id, "abcd0001")
remote_id = nameid.text.strip()
print remote_id
print self.ident.map
local = self.ident.local_name(sp_id, remote_id)
assert local == "abcd0001"
assert self.ident.local_name(sp_id, "pseudo random string") is None
assert self.ident.local_name(sp_id+":x", remote_id) is None
# Always get the same
nameid2 = self.ident.persistent_nameid(sp_id, "abcd0001")
assert nameid.text.strip() == nameid2.text.strip()
def test_transient_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.ident.transient_nameid(sp_id, "abcd0001")
remote_id = nameid.text.strip()
print remote_id
print self.ident.map
local = self.ident.local_name(sp_id, remote_id)
assert local == "abcd0001"
assert self.ident.local_name(sp_id, "pseudo random string") is None
assert self.ident.local_name(sp_id+":x", remote_id) is None
# Getting a new, means really getting a new !
nameid2 = self.ident.transient_nameid(sp_id, "abcd0001")
assert nameid.text.strip() != nameid2.text.strip()
def teardown_class(self):
if os.path.exists("foobar.db"):
os.unlink("foobar.db")
class TestServer1():
def setup_class(self):

View File

@ -406,8 +406,8 @@ class TestClientWithDummy():
http_args["headers"] = [('Content-type','application/x-www-form-urlencoded')]
response = self.client.send(**http_args)
_dic = unpack_form(response["data"][3], "SAMLResponse")
print response.text
_dic = unpack_form(response.text[3], "SAMLResponse")
resp = self.client.parse_authn_request_response(_dic["SAMLResponse"],
BINDING_HTTP_POST,
{id: "/"})

View File

@ -28,9 +28,7 @@ def get_msg(hinfo, binding, response=False):
msg = hinfo["data"]
else:
msg = ""
for header, val in hinfo["headers"]:
if header == "Location":
return parse_qs(val.split("?")[1])["ID"][0]
return parse_qs(hinfo["url"].split("?")[1])["ID"][0]
else: # BINDING_HTTP_REDIRECT
parts = urlparse(hinfo["headers"][0][1])
msg = parse_qs(parts.query)["SAMLRequest"][0]