This commit is contained in:
Roland Hedberg
2012-11-14 17:34:24 +01:00
parent 74cf8659e1
commit 3db08b124e
13 changed files with 95 additions and 71 deletions

8
.idea/libraries/sass_stdlib.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<component name="libraryTable">
<library name="sass-stdlib">
<CLASSES />
<SOURCES>
<root url="file://$APPLICATION_HOME_DIR$/plugins/sass/lib/stubs/sass_functions.scss" />
</SOURCES>
</library>
</component>

View File

@@ -40,6 +40,8 @@ from repoze.who.plugins.form import FormPluginBase
from saml2 import ecp from saml2 import ecp
from saml2.client import Saml2Client from saml2.client import Saml2Client
from saml2.discovery import discovery_service_response
from saml2.discovery import discovery_service_request_url
from saml2.s_utils import sid from saml2.s_utils import sid
from saml2.config import config_factory from saml2.config import config_factory
from saml2.profile import paos from saml2.profile import paos
@@ -111,6 +113,8 @@ class ECP_response(object):
[('Content-Type', "text/xml")]) [('Content-Type', "text/xml")])
return [self.content] return [self.content]
class SAML2Plugin(FormPluginBase): class SAML2Plugin(FormPluginBase):
implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider) implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider)
@@ -242,14 +246,14 @@ class SAML2Plugin(FormPluginBase):
return self._wayf_redirect(came_from) return self._wayf_redirect(came_from)
elif self.discovery: elif self.discovery:
if query: if query:
idp_entity_id = self.saml_client.get_idp_from_discovery_service( idp_entity_id = discovery_service_response(
query=environ.get("QUERY_STRING")) query=environ.get("QUERY_STRING"))
else: else:
sid_ = sid() sid_ = sid()
self.outstanding_queries[sid_] = came_from self.outstanding_queries[sid_] = came_from
logger.info("Redirect to Discovery Service function") logger.info("Redirect to Discovery Service function")
loc = self.saml_client.request_to_discovery_service( eid = self.saml_client.config.entity_id
self.discovery) loc = discovery_service_request_url(eid, self.discovery)
return -1, HTTPSeeOther(headers = [('Location',loc)]) return -1, HTTPSeeOther(headers = [('Location',loc)])
else: else:
return -1, HTTPNotImplemented(detail='No WAYF or DJ present!') return -1, HTTPNotImplemented(detail='No WAYF or DJ present!')
@@ -273,13 +277,13 @@ class SAML2Plugin(FormPluginBase):
environ["myapp.came_from"] = came_from environ["myapp.came_from"] = came_from
logger.debug("[sp.challenge] RelayState >> %s" % came_from) logger.debug("[sp.challenge] RelayState >> %s" % came_from)
# Am I part of a virtual organization ? # Am I part of a virtual organization or more than one ?
try: try:
vorg_name = environ["myapp.vo"] vorg_name = environ["myapp.vo"]
except KeyError: except KeyError:
try: try:
vorg_name = self.saml_client.vorg.vorg_name vorg_name = self.saml_client.vorg.keys()[1]
except AttributeError: except IndexError:
vorg_name = "" vorg_name = ""
logger.info("[sp.challenge] VO: %s" % vorg_name) logger.info("[sp.challenge] VO: %s" % vorg_name)
@@ -300,9 +304,9 @@ class SAML2Plugin(FormPluginBase):
logger.info("[sp.challenge] idp_url: %s" % idp_url) logger.info("[sp.challenge] idp_url: %s" % idp_url)
# Do the AuthnRequest # Do the AuthnRequest
(sid_, result) = self.saml_client.authenticate(idp_url, sid_, result = self.saml_client.do_authenticate(idp_url,
relay_state=came_from, relay_state=came_from,
vorg=vorg_name) vorg=vorg_name)
# remember the request # remember the request
self.outstanding_queries[sid_] = came_from self.outstanding_queries[sid_] = came_from
@@ -466,9 +470,9 @@ class SAML2Plugin(FormPluginBase):
if "pysaml2_vo_expanded" not in identity: if "pysaml2_vo_expanded" not in identity:
# is this a Virtual Organization situation # is this a Virtual Organization situation
if self.saml_client.vorg: for vo in self.saml_client.vorg.values():
try: try:
if self.saml_client.vorg.do_aggregation(subject_id): if vo.do_aggregation(subject_id):
# Get the extended identity # Get the extended identity
identity["user"] = self.saml_client.users.get_identity( identity["user"] = self.saml_client.users.get_identity(
subject_id)[0] subject_id)[0]

View File

@@ -90,7 +90,7 @@ class Saml2Client(Base):
else: else:
raise Exception("Unknown binding type: %s" % binding) raise Exception("Unknown binding type: %s" % binding)
return response return req.id, response
def global_logout(self, subject_id, reason="", expire=None, sign=None, def global_logout(self, subject_id, reason="", expire=None, sign=None,
return_to="/"): return_to="/"):

View File

@@ -83,13 +83,12 @@ class Base(object):
""" The basic pySAML2 service provider class """ """ The basic pySAML2 service provider class """
def __init__(self, config=None, identity_cache=None, state_cache=None, def __init__(self, config=None, identity_cache=None, state_cache=None,
virtual_organization=None, config_file=""): virtual_organization="",config_file=""):
""" """
:param config: A saml2.config.Config instance :param config: A saml2.config.Config instance
:param identity_cache: Where the class should store identity information :param identity_cache: Where the class should store identity information
:param state_cache: Where the class should keep state information :param state_cache: Where the class should keep state information
:param virtual_organization: Which if any virtual organization this :param virtual_organization: A specific virtual organization
SP belongs to
""" """
self.users = Population(identity_cache) self.users = Population(identity_cache)
@@ -107,6 +106,10 @@ class Base(object):
else: else:
raise Exception("Missing configuration") raise Exception("Missing configuration")
if self.config.vorg:
for vo in self.config.vorg.values():
vo.sp = self
self.metadata = self.config.metadata self.metadata = self.config.metadata
self.config.setup_logger() self.config.setup_logger()
@@ -118,7 +121,10 @@ class Base(object):
self.sec = security_context(self.config) self.sec = security_context(self.config)
if virtual_organization: if virtual_organization:
self.vorg = VirtualOrg(self, virtual_organization) if isinstance(virtual_organization, basestring):
self.vorg = self.config.vorg[virtual_organization]
elif isinstance(virtual_organization, VirtualOrg):
self.vorg = virtual_organization
else: else:
self.vorg = None self.vorg = None

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
from saml2.virtual_org import VirtualOrg
__author__ = 'rolandh' __author__ = 'rolandh'
@@ -113,7 +114,7 @@ class Config(object):
self.name_form=None self.name_form=None
self.virtual_organization=None self.virtual_organization=None
self.logger=None self.logger=None
self.only_use_keys_in_metadata=None self.only_use_keys_in_metadata=True
self.logout_requests_signed=None self.logout_requests_signed=None
self.disable_ssl_certificate_validation=None self.disable_ssl_certificate_validation=None
self.context = "" self.context = ""
@@ -121,6 +122,7 @@ class Config(object):
self.metadata=None self.metadata=None
self.policy=None self.policy=None
self.serves = [] self.serves = []
self.vorg = {}
# def copy_into(self, typ=""): # def copy_into(self, typ=""):
# if typ == "sp": # if typ == "sp":
@@ -200,6 +202,13 @@ class Config(object):
:return: The Configuration instance :return: The Configuration instance
""" """
for arg in COMMON_ARGS: for arg in COMMON_ARGS:
if arg == "virtual_organization":
if "virtual_organization" in cnf:
for key,val in cnf["virtual_organization"].items():
self.vorg[key] = VirtualOrg(None, key, val)
continue
try: try:
setattr(self, arg, cnf[arg]) setattr(self, arg, cnf[arg])
except KeyError: except KeyError:

View File

@@ -269,7 +269,7 @@ class AuthnResponse(StatusResponse):
elif self.allow_unsolicited: elif self.allow_unsolicited:
pass pass
else: else:
logger("Unsolicited response") logger.exception("Unsolicited response")
raise Exception("Unsolicited response") raise Exception("Unsolicited response")
return self return self

View File

@@ -110,22 +110,18 @@ class Identifier(object):
return temp_id return temp_id
def _get_vo_identifier(self, sp_name_qualifier, userid, identity): def _get_vo_identifier(self, sp_name_qualifier, identity):
try: try:
vo_conf = self.voconf[sp_name_qualifier] vo = self.voconf[sp_name_qualifier]
if "common_identifier" in vo_conf: try:
try: subj_id = identity[vo.common_identifier]
subj_id = identity[vo_conf["common_identifier"]] except KeyError:
except KeyError: raise MissingValue("Common identifier")
raise MissingValue("Common identifier")
else:
return self.persistent_nameid(sp_name_qualifier, userid)
except (KeyError, TypeError): except (KeyError, TypeError):
raise UnknownVO("%s" % sp_name_qualifier) raise UnknownVO("%s" % sp_name_qualifier)
try: nameid_format = vo.nameid_format
nameid_format = vo_conf["nameid_format"] if not nameid_format:
except KeyError:
nameid_format = saml.NAMEID_FORMAT_PERSISTENT nameid_format = saml.NAMEID_FORMAT_PERSISTENT
return saml.NameID(format=nameid_format, return saml.NameID(format=nameid_format,
@@ -189,9 +185,9 @@ class Identifier(object):
if name_id_policy and name_id_policy.sp_name_qualifier: if name_id_policy and name_id_policy.sp_name_qualifier:
try: try:
return self._get_vo_identifier(name_id_policy.sp_name_qualifier, return self._get_vo_identifier(name_id_policy.sp_name_qualifier,
userid, identity) identity)
except Exception: except Exception, exc:
pass print >> sys.stderr, "%s:%s" % (exc.__class__.__name__, exc)
if sp_nid: if sp_nid:
nameid_format = sp_nid[0] nameid_format = sp_nid[0]

View File

@@ -1,15 +1,23 @@
import logging import logging
from saml2.attribute_resolver import AttributeResolver from saml2.attribute_resolver import AttributeResolver
from saml2.saml import NAMEID_FORMAT_PERSISTENT
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class VirtualOrg(object): class VirtualOrg(object):
def __init__(self, sp, vorg): def __init__(self, sp, vorg, cnf):
self.sp = sp # The parent SP client instance self.sp = sp # The parent SP client instance
self.config = sp.config self._name = vorg
self.vorg_name = vorg self.common_identifier = cnf["common_identifier"]
self.vorg_conf = self.config.vo_conf(self.vorg_name) try:
self.member = cnf["member"]
except KeyError:
self.member = []
try:
self.nameid_format = cnf["nameid_format"]
except KeyError:
self.nameid_format = NAMEID_FORMAT_PERSISTENT
def _cache_session(self, session_info): def _cache_session(self, session_info):
return True return True
@@ -18,25 +26,15 @@ class VirtualOrg(object):
Get the member of the Virtual Organization from the metadata, Get the member of the Virtual Organization from the metadata,
more specifically from AffiliationDescriptor. more specifically from AffiliationDescriptor.
""" """
return self.config.metadata.vo_members(self.vorg_name) return self.sp.config.metadata.vo_members(self._name)
def _vo_conf_members(self):
"""
Get the member of the Virtual Organization from the configuration.
"""
try:
return self.vorg_conf["member"]
except (KeyError, TypeError):
return []
def members_to_ask(self, subject_id): def members_to_ask(self, subject_id):
"""Find the member of the Virtual Organization that I haven't already """Find the member of the Virtual Organization that I haven't already
spoken too spoken too
""" """
vo_members = self._affiliation_members() vo_members = self._affiliation_members()
for member in self._vo_conf_members(): for member in self.member:
if member not in vo_members: if member not in vo_members:
vo_members.append(member) vo_members.append(member)
@@ -51,7 +49,7 @@ class VirtualOrg(object):
if ava == {}: if ava == {}:
return None return None
ident = self.vorg_conf["common_identifier"] ident = self.common_identifier
try: try:
return ava[ident][0] return ava[ident][0]
@@ -61,16 +59,16 @@ class VirtualOrg(object):
def do_aggregation(self, subject_id): def do_aggregation(self, subject_id):
logger.info("** Do VO aggregation **\nSubjectID: %s, VO:%s" % ( logger.info("** Do VO aggregation **\nSubjectID: %s, VO:%s" % (
subject_id, self.vorg_name)) subject_id, self._name))
to_ask = self.members_to_ask(subject_id) to_ask = self.members_to_ask(subject_id)
if to_ask: if to_ask:
# Find the NameIDFormat and the SPNameQualifier # Find the NameIDFormat and the SPNameQualifier
if self.vorg_conf and "nameid_format" in self.vorg_conf: if self.nameid_format:
name_id_format = self.vorg_conf["nameid_format"] name_id_format = self.nameid_format
sp_name_qualifier = "" sp_name_qualifier = ""
else: else:
sp_name_qualifier = self.vorg_name sp_name_qualifier = self._name
name_id_format = "" name_id_format = ""
com_identifier = self.get_common_identifier(subject_id) com_identifier = self.get_common_identifier(subject_id)

View File

@@ -192,7 +192,7 @@ def test_2():
assert len(c._sp_idp) == 1 assert len(c._sp_idp) == 1
assert c._sp_idp.keys() == [""] assert c._sp_idp.keys() == [""]
assert c._sp_idp.values() == ["https://example.com/saml2/idp/SSOService.php"] assert c._sp_idp.values() == ["https://example.com/saml2/idp/SSOService.php"]
assert c.only_use_keys_in_metadata is None assert c.only_use_keys_in_metadata is True
def test_minimum(): def test_minimum():
minimum = { minimum = {

View File

@@ -6,7 +6,6 @@ from saml2.config import IdPConfig
from saml2.server import Identifier from saml2.server import Identifier
from saml2.assertion import Policy from saml2.assertion import Policy
def _eq(l1,l2): def _eq(l1,l2):
return set(l1) == set(l2) return set(l1) == set(l2)
@@ -34,6 +33,8 @@ CONFIG = IdPConfig().load({
"common_identifier": "uid", "common_identifier": "uid",
}, },
"http://vo.example.org/design":{ "http://vo.example.org/design":{
"nameid_format" : NAMEID_FORMAT_PERSISTENT,
"common_identifier": "uid",
} }
} }
}) })
@@ -53,7 +54,7 @@ NAME_ID_POLICY_2 = """<?xml version="1.0" encoding="utf-8"?>
class TestIdentifier(): class TestIdentifier():
def setup_class(self): def setup_class(self):
self.id = Identifier("subject.db", CONFIG.virtual_organization) self.id = Identifier("subject.db", CONFIG.vorg)
def test_persistent_1(self): def test_persistent_1(self):
policy = Policy({ policy = Policy({
@@ -109,16 +110,17 @@ class TestIdentifier():
}) })
name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_1) name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_1)
print name_id_policy
print self.id.voconf
nameid = self.id.construct_nameid(policy, "foobar", nameid = self.id.construct_nameid(policy, "foobar",
"urn:mace:example.com:sp:1", "urn:mace:example.com:sp:1",
{"uid": "foobar01"}, {"uid": "foobar01"},
name_id_policy) name_id_policy)
print nameid
assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format']) assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format'])
assert nameid.sp_name_qualifier == 'http://vo.example.org/biomed' assert nameid.sp_name_qualifier == 'http://vo.example.org/biomed'
assert nameid.format == \ assert nameid.format == 'urn:oid:2.16.756.1.2.5.1.1.1-NameID'
CONFIG.virtual_organization['http://vo.example.org/biomed'][
"nameid_format"]
assert nameid.text == "foobar01" assert nameid.text == "foobar01"
def test_vo_2(self): def test_vo_2(self):
@@ -142,5 +144,5 @@ class TestIdentifier():
assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format']) assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format'])
assert nameid.sp_name_qualifier == 'http://vo.example.org/design' assert nameid.sp_name_qualifier == 'http://vo.example.org/design'
assert nameid.format == NAMEID_FORMAT_PERSISTENT assert nameid.format == NAMEID_FORMAT_PERSISTENT
assert nameid.text != "foobar01" assert nameid.text == "foobar01"

View File

@@ -45,6 +45,7 @@ class TestAuthnResponse:
policy=policy) policy=policy)
self.conf = config_factory("sp", "server_conf") self.conf = config_factory("sp", "server_conf")
self.conf.only_use_keys_in_metadata = False
self.ar = authn_response(self.conf, "http://lingon.catalogix.se:8087/") self.ar = authn_response(self.conf, "http://lingon.catalogix.se:8087/")
def test_verify_1(self): def test_verify_1(self):

View File

@@ -201,7 +201,7 @@ class TestClient:
assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT
def test_create_auth_request_vo(self): def test_create_auth_request_vo(self):
assert self.client.config.virtual_organization.keys() == [ assert self.client.config.vorg.keys() == [
"urn:mace:example.com:it:tek"] "urn:mace:example.com:it:tek"]
ar_str = "%s" % self.client.create_authn_request( ar_str = "%s" % self.client.create_authn_request(
@@ -337,7 +337,7 @@ class TestClient:
def test_authenticate(self): def test_authenticate(self):
print self.client.config.idps() print self.client.config.idps()
response = self.client.do_authenticate( id, response = self.client.do_authenticate(
"urn:mace:example.com:saml:roland:idp", "urn:mace:example.com:saml:roland:idp",
"http://www.example.com/relay_state") "http://www.example.com/relay_state")
assert response[0] == "Location" assert response[0] == "Location"
@@ -349,7 +349,7 @@ class TestClient:
authnreq = samlp.authn_request_from_string(saml_request) authnreq = samlp.authn_request_from_string(saml_request)
def test_authenticate_no_args(self): def test_authenticate_no_args(self):
response = self.client.do_authenticate(relay_state="http://www.example.com/relay_state") id, response = self.client.do_authenticate(relay_state="http://www.example.com/relay_state")
assert response[0] == "Location" assert response[0] == "Location"
o = urlparse(response[1]) o = urlparse(response[1])
qdict = parse_qs(o.query) qdict = parse_qs(o.query)

View File

@@ -24,8 +24,8 @@ class TestVirtualOrg():
conf.load_file("server_conf") conf.load_file("server_conf")
self.sp = Saml2Client(conf) self.sp = Saml2Client(conf)
vo_name = conf.virtual_organization.keys()[0] vo_name = conf.vorg.keys()[0]
self.vo = VirtualOrg(self.sp, vo_name) self.vo = conf.vorg[vo_name]
add_derek_info(self.sp) add_derek_info(self.sp)
def test_mta(self): def test_mta(self):
@@ -53,7 +53,7 @@ class TestVirtualOrg_2():
def setup_class(self): def setup_class(self):
conf = config.SPConfig() conf = config.SPConfig()
conf.load_file("server_conf") conf.load_file("server_conf")
vo_name = conf.virtual_organization.keys()[0] vo_name = conf.vorg.keys()[0]
self.sp = Saml2Client(conf, virtual_organization=vo_name) self.sp = Saml2Client(conf, virtual_organization=vo_name)
add_derek_info(self.sp) add_derek_info(self.sp)