diff --git a/.idea/libraries/sass_stdlib.xml b/.idea/libraries/sass_stdlib.xml new file mode 100644 index 0000000..546bfd1 --- /dev/null +++ b/.idea/libraries/sass_stdlib.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/src/s2repoze/plugins/sp.py b/src/s2repoze/plugins/sp.py index e7b0eac..1558340 100644 --- a/src/s2repoze/plugins/sp.py +++ b/src/s2repoze/plugins/sp.py @@ -40,6 +40,8 @@ from repoze.who.plugins.form import FormPluginBase from saml2 import ecp from saml2.client import Saml2Client +from saml2.discovery import discovery_service_response +from saml2.discovery import discovery_service_request_url from saml2.s_utils import sid from saml2.config import config_factory from saml2.profile import paos @@ -111,6 +113,8 @@ class ECP_response(object): [('Content-Type', "text/xml")]) return [self.content] + + class SAML2Plugin(FormPluginBase): implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider) @@ -242,14 +246,14 @@ class SAML2Plugin(FormPluginBase): return self._wayf_redirect(came_from) elif self.discovery: if query: - idp_entity_id = self.saml_client.get_idp_from_discovery_service( + idp_entity_id = discovery_service_response( query=environ.get("QUERY_STRING")) else: sid_ = sid() self.outstanding_queries[sid_] = came_from logger.info("Redirect to Discovery Service function") - loc = self.saml_client.request_to_discovery_service( - self.discovery) + eid = self.saml_client.config.entity_id + loc = discovery_service_request_url(eid, self.discovery) return -1, HTTPSeeOther(headers = [('Location',loc)]) else: return -1, HTTPNotImplemented(detail='No WAYF or DJ present!') @@ -273,13 +277,13 @@ class SAML2Plugin(FormPluginBase): environ["myapp.came_from"] = came_from logger.debug("[sp.challenge] RelayState >> %s" % came_from) - # Am I part of a virtual organization ? + # Am I part of a virtual organization or more than one ? try: vorg_name = environ["myapp.vo"] except KeyError: try: - vorg_name = self.saml_client.vorg.vorg_name - except AttributeError: + vorg_name = self.saml_client.vorg.keys()[1] + except IndexError: vorg_name = "" logger.info("[sp.challenge] VO: %s" % vorg_name) @@ -300,9 +304,9 @@ class SAML2Plugin(FormPluginBase): logger.info("[sp.challenge] idp_url: %s" % idp_url) # Do the AuthnRequest - (sid_, result) = self.saml_client.authenticate(idp_url, - relay_state=came_from, - vorg=vorg_name) + sid_, result = self.saml_client.do_authenticate(idp_url, + relay_state=came_from, + vorg=vorg_name) # remember the request self.outstanding_queries[sid_] = came_from @@ -466,9 +470,9 @@ class SAML2Plugin(FormPluginBase): if "pysaml2_vo_expanded" not in identity: # is this a Virtual Organization situation - if self.saml_client.vorg: + for vo in self.saml_client.vorg.values(): try: - if self.saml_client.vorg.do_aggregation(subject_id): + if vo.do_aggregation(subject_id): # Get the extended identity identity["user"] = self.saml_client.users.get_identity( subject_id)[0] diff --git a/src/saml2/client.py b/src/saml2/client.py index 789ae01..4e98236 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -90,7 +90,7 @@ class Saml2Client(Base): else: raise Exception("Unknown binding type: %s" % binding) - return response + return req.id, response def global_logout(self, subject_id, reason="", expire=None, sign=None, return_to="/"): diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index c66a14e..9c1dc12 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -83,13 +83,12 @@ class Base(object): """ The basic pySAML2 service provider class """ def __init__(self, config=None, identity_cache=None, state_cache=None, - virtual_organization=None, config_file=""): + virtual_organization="",config_file=""): """ :param config: A saml2.config.Config instance :param identity_cache: Where the class should store identity information :param state_cache: Where the class should keep state information - :param virtual_organization: Which if any virtual organization this - SP belongs to + :param virtual_organization: A specific virtual organization """ self.users = Population(identity_cache) @@ -107,6 +106,10 @@ class Base(object): else: raise Exception("Missing configuration") + if self.config.vorg: + for vo in self.config.vorg.values(): + vo.sp = self + self.metadata = self.config.metadata self.config.setup_logger() @@ -118,7 +121,10 @@ class Base(object): self.sec = security_context(self.config) if virtual_organization: - self.vorg = VirtualOrg(self, virtual_organization) + if isinstance(virtual_organization, basestring): + self.vorg = self.config.vorg[virtual_organization] + elif isinstance(virtual_organization, VirtualOrg): + self.vorg = virtual_organization else: self.vorg = None diff --git a/src/saml2/config.py b/src/saml2/config.py index a030689..bedead5 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from saml2.virtual_org import VirtualOrg __author__ = 'rolandh' @@ -113,7 +114,7 @@ class Config(object): self.name_form=None self.virtual_organization=None self.logger=None - self.only_use_keys_in_metadata=None + self.only_use_keys_in_metadata=True self.logout_requests_signed=None self.disable_ssl_certificate_validation=None self.context = "" @@ -121,6 +122,7 @@ class Config(object): self.metadata=None self.policy=None self.serves = [] + self.vorg = {} # def copy_into(self, typ=""): # if typ == "sp": @@ -200,6 +202,13 @@ class Config(object): :return: The Configuration instance """ for arg in COMMON_ARGS: + if arg == "virtual_organization": + if "virtual_organization" in cnf: + for key,val in cnf["virtual_organization"].items(): + self.vorg[key] = VirtualOrg(None, key, val) + continue + + try: setattr(self, arg, cnf[arg]) except KeyError: diff --git a/src/saml2/response.py b/src/saml2/response.py index c6710a9..7445bb5 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -269,7 +269,7 @@ class AuthnResponse(StatusResponse): elif self.allow_unsolicited: pass else: - logger("Unsolicited response") + logger.exception("Unsolicited response") raise Exception("Unsolicited response") return self diff --git a/src/saml2/server.py b/src/saml2/server.py index 739588d..448095a 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -110,22 +110,18 @@ class Identifier(object): return temp_id - def _get_vo_identifier(self, sp_name_qualifier, userid, identity): + def _get_vo_identifier(self, sp_name_qualifier, identity): try: - vo_conf = self.voconf[sp_name_qualifier] - if "common_identifier" in vo_conf: - try: - subj_id = identity[vo_conf["common_identifier"]] - except KeyError: - raise MissingValue("Common identifier") - else: - return self.persistent_nameid(sp_name_qualifier, userid) + vo = self.voconf[sp_name_qualifier] + try: + subj_id = identity[vo.common_identifier] + except KeyError: + raise MissingValue("Common identifier") except (KeyError, TypeError): raise UnknownVO("%s" % sp_name_qualifier) - try: - nameid_format = vo_conf["nameid_format"] - except KeyError: + nameid_format = vo.nameid_format + if not nameid_format: nameid_format = saml.NAMEID_FORMAT_PERSISTENT return saml.NameID(format=nameid_format, @@ -189,9 +185,9 @@ class Identifier(object): if name_id_policy and name_id_policy.sp_name_qualifier: try: return self._get_vo_identifier(name_id_policy.sp_name_qualifier, - userid, identity) - except Exception: - pass + identity) + except Exception, exc: + print >> sys.stderr, "%s:%s" % (exc.__class__.__name__, exc) if sp_nid: nameid_format = sp_nid[0] diff --git a/src/saml2/virtual_org.py b/src/saml2/virtual_org.py index c901b14..9369e60 100644 --- a/src/saml2/virtual_org.py +++ b/src/saml2/virtual_org.py @@ -1,15 +1,23 @@ import logging from saml2.attribute_resolver import AttributeResolver +from saml2.saml import NAMEID_FORMAT_PERSISTENT logger = logging.getLogger(__name__) class VirtualOrg(object): - def __init__(self, sp, vorg): - self.sp = sp # The parent SP client instance - self.config = sp.config - self.vorg_name = vorg - self.vorg_conf = self.config.vo_conf(self.vorg_name) - + def __init__(self, sp, vorg, cnf): + self.sp = sp # The parent SP client instance + self._name = vorg + self.common_identifier = cnf["common_identifier"] + try: + self.member = cnf["member"] + except KeyError: + self.member = [] + try: + self.nameid_format = cnf["nameid_format"] + except KeyError: + self.nameid_format = NAMEID_FORMAT_PERSISTENT + def _cache_session(self, session_info): return True @@ -18,25 +26,15 @@ class VirtualOrg(object): Get the member of the Virtual Organization from the metadata, more specifically from AffiliationDescriptor. """ - return self.config.metadata.vo_members(self.vorg_name) - - def _vo_conf_members(self): - """ - Get the member of the Virtual Organization from the configuration. - """ - - try: - return self.vorg_conf["member"] - except (KeyError, TypeError): - return [] - + return self.sp.config.metadata.vo_members(self._name) + def members_to_ask(self, subject_id): """Find the member of the Virtual Organization that I haven't already spoken too """ vo_members = self._affiliation_members() - for member in self._vo_conf_members(): + for member in self.member: if member not in vo_members: vo_members.append(member) @@ -51,7 +49,7 @@ class VirtualOrg(object): if ava == {}: return None - ident = self.vorg_conf["common_identifier"] + ident = self.common_identifier try: return ava[ident][0] @@ -61,16 +59,16 @@ class VirtualOrg(object): def do_aggregation(self, subject_id): logger.info("** Do VO aggregation **\nSubjectID: %s, VO:%s" % ( - subject_id, self.vorg_name)) + subject_id, self._name)) to_ask = self.members_to_ask(subject_id) if to_ask: # Find the NameIDFormat and the SPNameQualifier - if self.vorg_conf and "nameid_format" in self.vorg_conf: - name_id_format = self.vorg_conf["nameid_format"] + if self.nameid_format: + name_id_format = self.nameid_format sp_name_qualifier = "" else: - sp_name_qualifier = self.vorg_name + sp_name_qualifier = self._name name_id_format = "" com_identifier = self.get_common_identifier(subject_id) diff --git a/tests/test_31_config.py b/tests/test_31_config.py index 15624d7..de67c4b 100644 --- a/tests/test_31_config.py +++ b/tests/test_31_config.py @@ -192,7 +192,7 @@ def test_2(): assert len(c._sp_idp) == 1 assert c._sp_idp.keys() == [""] assert c._sp_idp.values() == ["https://example.com/saml2/idp/SSOService.php"] - assert c.only_use_keys_in_metadata is None + assert c.only_use_keys_in_metadata is True def test_minimum(): minimum = { diff --git a/tests/test_33_identifier.py b/tests/test_33_identifier.py index 6082528..d8a266f 100644 --- a/tests/test_33_identifier.py +++ b/tests/test_33_identifier.py @@ -6,7 +6,6 @@ from saml2.config import IdPConfig from saml2.server import Identifier from saml2.assertion import Policy - def _eq(l1,l2): return set(l1) == set(l2) @@ -34,6 +33,8 @@ CONFIG = IdPConfig().load({ "common_identifier": "uid", }, "http://vo.example.org/design":{ + "nameid_format" : NAMEID_FORMAT_PERSISTENT, + "common_identifier": "uid", } } }) @@ -53,7 +54,7 @@ NAME_ID_POLICY_2 = """ class TestIdentifier(): def setup_class(self): - self.id = Identifier("subject.db", CONFIG.virtual_organization) + self.id = Identifier("subject.db", CONFIG.vorg) def test_persistent_1(self): policy = Policy({ @@ -109,16 +110,17 @@ class TestIdentifier(): }) name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_1) + print name_id_policy + print self.id.voconf nameid = self.id.construct_nameid(policy, "foobar", "urn:mace:example.com:sp:1", {"uid": "foobar01"}, name_id_policy) - + + print nameid assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format']) assert nameid.sp_name_qualifier == 'http://vo.example.org/biomed' - assert nameid.format == \ - CONFIG.virtual_organization['http://vo.example.org/biomed'][ - "nameid_format"] + assert nameid.format == 'urn:oid:2.16.756.1.2.5.1.1.1-NameID' assert nameid.text == "foobar01" def test_vo_2(self): @@ -142,5 +144,5 @@ class TestIdentifier(): assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format']) assert nameid.sp_name_qualifier == 'http://vo.example.org/design' assert nameid.format == NAMEID_FORMAT_PERSISTENT - assert nameid.text != "foobar01" + assert nameid.text == "foobar01" diff --git a/tests/test_44_authnresp.py b/tests/test_44_authnresp.py index dc2db5e..e71195a 100644 --- a/tests/test_44_authnresp.py +++ b/tests/test_44_authnresp.py @@ -45,6 +45,7 @@ class TestAuthnResponse: policy=policy) self.conf = config_factory("sp", "server_conf") + self.conf.only_use_keys_in_metadata = False self.ar = authn_response(self.conf, "http://lingon.catalogix.se:8087/") def test_verify_1(self): diff --git a/tests/test_51_client.py b/tests/test_51_client.py index 4d86d15..f778ec2 100644 --- a/tests/test_51_client.py +++ b/tests/test_51_client.py @@ -201,7 +201,7 @@ class TestClient: assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT def test_create_auth_request_vo(self): - assert self.client.config.virtual_organization.keys() == [ + assert self.client.config.vorg.keys() == [ "urn:mace:example.com:it:tek"] ar_str = "%s" % self.client.create_authn_request( @@ -337,7 +337,7 @@ class TestClient: def test_authenticate(self): print self.client.config.idps() - response = self.client.do_authenticate( + id, response = self.client.do_authenticate( "urn:mace:example.com:saml:roland:idp", "http://www.example.com/relay_state") assert response[0] == "Location" @@ -349,7 +349,7 @@ class TestClient: authnreq = samlp.authn_request_from_string(saml_request) def test_authenticate_no_args(self): - response = self.client.do_authenticate(relay_state="http://www.example.com/relay_state") + id, response = self.client.do_authenticate(relay_state="http://www.example.com/relay_state") assert response[0] == "Location" o = urlparse(response[1]) qdict = parse_qs(o.query) diff --git a/tests/test_62_vo.py b/tests/test_62_vo.py index 43c3c2d..6c8c398 100644 --- a/tests/test_62_vo.py +++ b/tests/test_62_vo.py @@ -24,8 +24,8 @@ class TestVirtualOrg(): conf.load_file("server_conf") self.sp = Saml2Client(conf) - vo_name = conf.virtual_organization.keys()[0] - self.vo = VirtualOrg(self.sp, vo_name) + vo_name = conf.vorg.keys()[0] + self.vo = conf.vorg[vo_name] add_derek_info(self.sp) def test_mta(self): @@ -53,7 +53,7 @@ class TestVirtualOrg_2(): def setup_class(self): conf = config.SPConfig() conf.load_file("server_conf") - vo_name = conf.virtual_organization.keys()[0] + vo_name = conf.vorg.keys()[0] self.sp = Saml2Client(conf, virtual_organization=vo_name) add_derek_info(self.sp)