Made the test work.

This commit is contained in:
Roland Hedberg
2013-01-18 11:53:31 +01:00
parent 3624c769a4
commit 21880631d9
11 changed files with 117 additions and 81 deletions

View File

@@ -245,6 +245,11 @@ class Saml2Client(Base):
response = self.send_using_soap(query, destination) response = self.send_using_soap(query, destination)
if response: if response:
if not response_args:
response_args = {"binding": BINDING_SOAP}
else:
response_args["binding"] = BINDING_SOAP
logger.info("Verifying response") logger.info("Verifying response")
if response_args: if response_args:
response = _response_func(response, **response_args) response = _response_func(response, **response_args)

View File

@@ -364,24 +364,21 @@ class Base(Entity):
extensions=extensions, extensions=extensions,
sign=sign) sign=sign)
def create_assertion_id_request(self, assertion_id_refs, destination=None, def create_assertion_id_request(self, assertion_id_refs, **kwargs):
id=0, consent=None, extensions=None,
sign=False):
""" """
:param assertion_id_refs: :param assertion_id_refs:
:param destination: The IdP endpoint to send the request to :return: One ID ref
:param id: Message identifier
:param consent: If the principal gave her consent to this request
:param extensions: Possible request extensions
:param sign: Whether the request should be signed or not.
:return: AssertionIDRequest instance
""" """
id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs] # id_refs = [AssertionIDRef(text=s) for s in assertion_id_refs]
#
return self._message(AssertionIDRequest, destination, id, consent, # return self._message(AssertionIDRequest, destination, id, consent,
extensions, sign, assertion_id_ref=id_refs ) # extensions, sign, assertion_id_ref=id_refs )
if isinstance(assertion_id_refs, basestring):
return assertion_id_refs
else:
return assertion_id_refs[0]
def create_authn_query(self, subject, destination=None, def create_authn_query(self, subject, destination=None,
authn_context=None, session_index="", authn_context=None, session_index="",
@@ -516,13 +513,16 @@ class Base(Entity):
res = self._parse_response(response, AssertionIDResponse, "", binding, res = self._parse_response(response, AssertionIDResponse, "", binding,
**kwargs) **kwargs)
return res return res
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def parse_attribute_query_response(self, response, binding): def parse_attribute_query_response(self, response, binding):
kwargs = {"entity_id": self.config.entityid,
"attribute_converters": self.config.attribute_converters}
return self._parse_response(response, AttributeResponse, return self._parse_response(response, AttributeResponse,
"attribute_consuming_service", binding) "attribute_consuming_service", binding,
**kwargs)
def parse_name_id_mapping_request_response(self, txt, binding=BINDING_SOAP): def parse_name_id_mapping_request_response(self, txt, binding=BINDING_SOAP):
""" """
@@ -531,4 +531,5 @@ class Base(Entity):
:param binding: Just a placeholder, it's always BINDING_SOAP :param binding: Just a placeholder, it's always BINDING_SOAP
:return: parsed and verified <NameIDMappingResponse> instance :return: parsed and verified <NameIDMappingResponse> instance
""" """
return self._parse_response(txt, NameIDMappingResponse, "", binding) return self._parse_response(txt, NameIDMappingResponse, "", binding)

View File

@@ -45,6 +45,9 @@ PAIRS = {
class ConnectionError(Exception): class ConnectionError(Exception):
pass pass
class HTTPError(Exception):
pass
def _since_epoch(cdate): def _since_epoch(cdate):
""" """
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT' :param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
@@ -217,9 +220,7 @@ class HTTPBase(object):
# msg should be an identifier # msg should be an identifier
info = { info = {
"data": "", "data": "",
"headers": [ "url": "%s?ID=%s" % (destination, message)
("Location", "%s?ID=%s" % (destination, message))
]
} }
else: else:
raise NotImplemented raise NotImplemented
@@ -273,8 +274,10 @@ class HTTPBase(object):
raise raise
if response: if response:
xmlstr = response.text if response.status_code == 200:
logger.info("SOAP response: %s" % xmlstr) logger.info("SOAP response: %s" % response.text)
return response return response.text
else:
raise HTTPError("%d:%s" % (response.status_code, response.error))
else: else:
return None return None

View File

@@ -16,7 +16,8 @@ __author__ = 'rolandh'
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id"] ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id",
"text"]
class Unknown(Exception): class Unknown(Exception):
pass pass
@@ -168,9 +169,13 @@ class IdentDB(object):
return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT, return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT,
sp_name_qualifier, name_qualifier) sp_name_qualifier, name_qualifier)
def permanent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""): def persistent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""):
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT, nameid = self.match_local_id(userid, sp_name_qualifier, name_qualifier)
sp_name_qualifier, name_qualifier) if nameid:
return nameid
else:
return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier, name_qualifier)
def find_local_id(self, name_id): def find_local_id(self, name_id):
""" """
@@ -191,8 +196,18 @@ class IdentDB(object):
nid = decode(val) nid = decode(val)
if nid.format == NAMEID_FORMAT_TRANSIENT: if nid.format == NAMEID_FORMAT_TRANSIENT:
continue continue
if getattr(nid, "sp_name_qualifier", "") == sp_name_qualifier: snq = getattr(nid, "sp_name_qualifier", "")
if getattr(nid, "name_qualifier", "") == name_qualifier: if snq and snq == sp_name_qualifier:
nq = getattr(nid, "name_qualifier", None)
if nq and nq == name_qualifier:
return nid
elif not nq and not name_qualifier:
return nid
elif not snq and not sp_name_qualifier:
nq = getattr(nid, "name_qualifier", None)
if nq and nq == name_qualifier:
return nid
elif not nq and not name_qualifier:
return nid return nid
except KeyError: except KeyError:
pass pass
@@ -279,4 +294,7 @@ class IdentDB(object):
try: try:
return self.db["%s:%s" % (userid, entity_id)] return self.db["%s:%s" % (userid, entity_id)]
except KeyError: except KeyError:
return None return None
def close(self):
self.db.close()

View File

@@ -106,7 +106,7 @@ class Server(Entity):
def close_shelve_db(self): def close_shelve_db(self):
"""Close the shelve db to prevent file system locking issues""" """Close the shelve db to prevent file system locking issues"""
if self.ident: if self.ident:
self.ident.map.close() self.ident.db.close()
def wants(self, sp_entity_id, index=None): def wants(self, sp_entity_id, index=None):
""" Returns what attributes the SP requires and which are optional """ Returns what attributes the SP requires and which are optional

View File

@@ -56,6 +56,11 @@ def parse_soap_enveloped_saml_attribute_query(text):
expected_tag = '{%s}AttributeQuery' % SAMLP_NAMESPACE expected_tag = '{%s}AttributeQuery' % SAMLP_NAMESPACE
return parse_soap_enveloped_saml_thingy(text, [expected_tag]) return parse_soap_enveloped_saml_thingy(text, [expected_tag])
def parse_soap_enveloped_saml_attribute_response(text):
tags = ['{%s}Response' % SAMLP_NAMESPACE,
'{%s}AttributeResponse' % SAMLP_NAMESPACE]
return parse_soap_enveloped_saml_thingy(text, tags)
def parse_soap_enveloped_saml_logout_request(text): def parse_soap_enveloped_saml_logout_request(text):
expected_tag = '{%s}LogoutRequest' % SAMLP_NAMESPACE expected_tag = '{%s}LogoutRequest' % SAMLP_NAMESPACE
return parse_soap_enveloped_saml_thingy(text, [expected_tag]) return parse_soap_enveloped_saml_thingy(text, [expected_tag])

View File

@@ -34,6 +34,12 @@ def unpack_form(_str, ver="SAMLRequest"):
return {ver:sr, "RelayState":rs} return {ver:sr, "RelayState":rs}
class DummyResponse(object):
def __init__(self, code, data, headers=None):
self.status_code = code
self.text = data
self.headers = headers or []
class FakeIDP(Server): class FakeIDP(Server):
def __init__(self, config_file=""): def __init__(self, config_file=""):
Server.__init__(self, config_file) Server.__init__(self, config_file)
@@ -52,12 +58,19 @@ class FakeIDP(Server):
if method == "GET": if method == "GET":
path, query = url.split("?") path, query = url.split("?")
qs_dict = parse_qs(kwargs["data"]) qs_dict = parse_qs(kwargs["data"])
req = qs_dict["SAMLRequest"][0]
rstate = qs_dict["RelayState"][0]
else: else:
# Could be either POST or SOAP
path = url path = url
qs_dict = parse_qs(kwargs["data"]) try:
qs_dict = parse_qs(kwargs["data"])
req = qs_dict["SAMLRequest"][0]
rstate = qs_dict["RelayState"][0]
except KeyError:
req = kwargs["data"]
rstate = ""
req = qs_dict["SAMLRequest"][0]
rstate = qs_dict["RelayState"][0]
response = "" response = ""
# Get service from path # Get service from path
@@ -105,9 +118,10 @@ class FakeIDP(Server):
response = "%s" % authn_resp response = "%s" % authn_resp
return pack.factory(_binding, response, _dict = pack.factory(_binding, response,
resp_args["destination"], relay_state, resp_args["destination"], relay_state,
"SAMLResponse") "SAMLResponse")
return DummyResponse(200, **_dict)
def attribute_query_endpoint(self, xml_str, binding): def attribute_query_endpoint(self, xml_str, binding):
if binding == BINDING_SOAP: if binding == BINDING_SOAP:
@@ -139,7 +153,7 @@ class FakeIDP(Server):
else: # Just POST else: # Just POST
response = "%s" % attr_resp response = "%s" % attr_resp
return response return DummyResponse(200, response)
def logout_endpoint(self, xml_str, binding): def logout_endpoint(self, xml_str, binding):
if binding == BINDING_SOAP: if binding == BINDING_SOAP:
@@ -164,4 +178,4 @@ class FakeIDP(Server):
else: # Just POST else: # Just POST
response = "%s" % _resp response = "%s" % _resp
return response return DummyResponse(200, response)

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import os
from saml2 import samlp from saml2 import samlp
from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT
@@ -142,4 +143,33 @@ class TestIdentifier():
assert nameid.sp_name_qualifier == 'http://vo.example.org/design' assert nameid.sp_name_qualifier == 'http://vo.example.org/design'
assert nameid.format == NAMEID_FORMAT_PERSISTENT assert nameid.format == NAMEID_FORMAT_PERSISTENT
assert nameid.text != "foobar01" assert nameid.text != "foobar01"
def test_persistent_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.id.persistent_nameid("abcd0001", sp_id)
remote_id = nameid.text.strip()
print remote_id
local = self.id.find_local_id(nameid)
assert local == "abcd0001"
# Always get the same
nameid2 = self.id.persistent_nameid("abcd0001", sp_id)
assert nameid.text.strip() == nameid2.text.strip()
def test_transient_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.id.transient_nameid("abcd0001", sp_id)
remote_id = nameid.text.strip()
print remote_id
local = self.id.find_local_id(nameid)
assert local == "abcd0001"
# Getting a new, means really getting a new !
nameid2 = self.id.transient_nameid(sp_id, "abcd0001")
assert nameid.text.strip() != nameid2.text.strip()
def teardown_class(self):
if os.path.exists("foobar.db"):
os.unlink("foobar.db")

View File

@@ -5,7 +5,8 @@ from urlparse import parse_qs
from saml2.saml import AUTHN_PASSWORD from saml2.saml import AUTHN_PASSWORD
from saml2.samlp import response_from_string from saml2.samlp import response_from_string
from saml2.server import Server, Identifier from saml2.server import Server
from saml2.ident import IdentDB
from saml2 import samlp, saml, client, config from saml2 import samlp, saml, client, config
from saml2 import s_utils from saml2 import s_utils
from saml2 import sigver from saml2 import sigver
@@ -21,45 +22,6 @@ import os
def _eq(l1,l2): def _eq(l1,l2):
return set(l1) == set(l2) return set(l1) == set(l2)
class TestIdentifier():
def setup_class(self):
self.ident = Identifier("foobar.db")
def test_persistent_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.ident.persistent_nameid(sp_id, "abcd0001")
remote_id = nameid.text.strip()
print remote_id
print self.ident.map
local = self.ident.local_name(sp_id, remote_id)
assert local == "abcd0001"
assert self.ident.local_name(sp_id, "pseudo random string") is None
assert self.ident.local_name(sp_id+":x", remote_id) is None
# Always get the same
nameid2 = self.ident.persistent_nameid(sp_id, "abcd0001")
assert nameid.text.strip() == nameid2.text.strip()
def test_transient_nameid(self):
sp_id = "urn:mace:umu.se:sp"
nameid = self.ident.transient_nameid(sp_id, "abcd0001")
remote_id = nameid.text.strip()
print remote_id
print self.ident.map
local = self.ident.local_name(sp_id, remote_id)
assert local == "abcd0001"
assert self.ident.local_name(sp_id, "pseudo random string") is None
assert self.ident.local_name(sp_id+":x", remote_id) is None
# Getting a new, means really getting a new !
nameid2 = self.ident.transient_nameid(sp_id, "abcd0001")
assert nameid.text.strip() != nameid2.text.strip()
def teardown_class(self):
if os.path.exists("foobar.db"):
os.unlink("foobar.db")
class TestServer1(): class TestServer1():
def setup_class(self): def setup_class(self):

View File

@@ -406,8 +406,8 @@ class TestClientWithDummy():
http_args["headers"] = [('Content-type','application/x-www-form-urlencoded')] http_args["headers"] = [('Content-type','application/x-www-form-urlencoded')]
response = self.client.send(**http_args) response = self.client.send(**http_args)
print response.text
_dic = unpack_form(response["data"][3], "SAMLResponse") _dic = unpack_form(response.text[3], "SAMLResponse")
resp = self.client.parse_authn_request_response(_dic["SAMLResponse"], resp = self.client.parse_authn_request_response(_dic["SAMLResponse"],
BINDING_HTTP_POST, BINDING_HTTP_POST,
{id: "/"}) {id: "/"})

View File

@@ -28,9 +28,7 @@ def get_msg(hinfo, binding, response=False):
msg = hinfo["data"] msg = hinfo["data"]
else: else:
msg = "" msg = ""
for header, val in hinfo["headers"]: return parse_qs(hinfo["url"].split("?")[1])["ID"][0]
if header == "Location":
return parse_qs(val.split("?")[1])["ID"][0]
else: # BINDING_HTTP_REDIRECT else: # BINDING_HTTP_REDIRECT
parts = urlparse(hinfo["headers"][0][1]) parts = urlparse(hinfo["headers"][0][1])
msg = parse_qs(parts.query)["SAMLRequest"][0] msg = parse_qs(parts.query)["SAMLRequest"][0]