This commit is contained in:
Roland Hedberg
2012-11-14 17:34:24 +01:00
parent 74cf8659e1
commit 3db08b124e
13 changed files with 95 additions and 71 deletions

8
.idea/libraries/sass_stdlib.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<component name="libraryTable">
<library name="sass-stdlib">
<CLASSES />
<SOURCES>
<root url="file://$APPLICATION_HOME_DIR$/plugins/sass/lib/stubs/sass_functions.scss" />
</SOURCES>
</library>
</component>

View File

@@ -40,6 +40,8 @@ from repoze.who.plugins.form import FormPluginBase
from saml2 import ecp
from saml2.client import Saml2Client
from saml2.discovery import discovery_service_response
from saml2.discovery import discovery_service_request_url
from saml2.s_utils import sid
from saml2.config import config_factory
from saml2.profile import paos
@@ -111,6 +113,8 @@ class ECP_response(object):
[('Content-Type', "text/xml")])
return [self.content]
class SAML2Plugin(FormPluginBase):
implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider)
@@ -242,14 +246,14 @@ class SAML2Plugin(FormPluginBase):
return self._wayf_redirect(came_from)
elif self.discovery:
if query:
idp_entity_id = self.saml_client.get_idp_from_discovery_service(
idp_entity_id = discovery_service_response(
query=environ.get("QUERY_STRING"))
else:
sid_ = sid()
self.outstanding_queries[sid_] = came_from
logger.info("Redirect to Discovery Service function")
loc = self.saml_client.request_to_discovery_service(
self.discovery)
eid = self.saml_client.config.entity_id
loc = discovery_service_request_url(eid, self.discovery)
return -1, HTTPSeeOther(headers = [('Location',loc)])
else:
return -1, HTTPNotImplemented(detail='No WAYF or DJ present!')
@@ -273,13 +277,13 @@ class SAML2Plugin(FormPluginBase):
environ["myapp.came_from"] = came_from
logger.debug("[sp.challenge] RelayState >> %s" % came_from)
# Am I part of a virtual organization ?
# Am I part of a virtual organization or more than one ?
try:
vorg_name = environ["myapp.vo"]
except KeyError:
try:
vorg_name = self.saml_client.vorg.vorg_name
except AttributeError:
vorg_name = self.saml_client.vorg.keys()[1]
except IndexError:
vorg_name = ""
logger.info("[sp.challenge] VO: %s" % vorg_name)
@@ -300,7 +304,7 @@ class SAML2Plugin(FormPluginBase):
logger.info("[sp.challenge] idp_url: %s" % idp_url)
# Do the AuthnRequest
(sid_, result) = self.saml_client.authenticate(idp_url,
sid_, result = self.saml_client.do_authenticate(idp_url,
relay_state=came_from,
vorg=vorg_name)
@@ -466,9 +470,9 @@ class SAML2Plugin(FormPluginBase):
if "pysaml2_vo_expanded" not in identity:
# is this a Virtual Organization situation
if self.saml_client.vorg:
for vo in self.saml_client.vorg.values():
try:
if self.saml_client.vorg.do_aggregation(subject_id):
if vo.do_aggregation(subject_id):
# Get the extended identity
identity["user"] = self.saml_client.users.get_identity(
subject_id)[0]

View File

@@ -90,7 +90,7 @@ class Saml2Client(Base):
else:
raise Exception("Unknown binding type: %s" % binding)
return response
return req.id, response
def global_logout(self, subject_id, reason="", expire=None, sign=None,
return_to="/"):

View File

@@ -83,13 +83,12 @@ class Base(object):
""" The basic pySAML2 service provider class """
def __init__(self, config=None, identity_cache=None, state_cache=None,
virtual_organization=None, config_file=""):
virtual_organization="",config_file=""):
"""
:param config: A saml2.config.Config instance
:param identity_cache: Where the class should store identity information
:param state_cache: Where the class should keep state information
:param virtual_organization: Which if any virtual organization this
SP belongs to
:param virtual_organization: A specific virtual organization
"""
self.users = Population(identity_cache)
@@ -107,6 +106,10 @@ class Base(object):
else:
raise Exception("Missing configuration")
if self.config.vorg:
for vo in self.config.vorg.values():
vo.sp = self
self.metadata = self.config.metadata
self.config.setup_logger()
@@ -118,7 +121,10 @@ class Base(object):
self.sec = security_context(self.config)
if virtual_organization:
self.vorg = VirtualOrg(self, virtual_organization)
if isinstance(virtual_organization, basestring):
self.vorg = self.config.vorg[virtual_organization]
elif isinstance(virtual_organization, VirtualOrg):
self.vorg = virtual_organization
else:
self.vorg = None

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python
from saml2.virtual_org import VirtualOrg
__author__ = 'rolandh'
@@ -113,7 +114,7 @@ class Config(object):
self.name_form=None
self.virtual_organization=None
self.logger=None
self.only_use_keys_in_metadata=None
self.only_use_keys_in_metadata=True
self.logout_requests_signed=None
self.disable_ssl_certificate_validation=None
self.context = ""
@@ -121,6 +122,7 @@ class Config(object):
self.metadata=None
self.policy=None
self.serves = []
self.vorg = {}
# def copy_into(self, typ=""):
# if typ == "sp":
@@ -200,6 +202,13 @@ class Config(object):
:return: The Configuration instance
"""
for arg in COMMON_ARGS:
if arg == "virtual_organization":
if "virtual_organization" in cnf:
for key,val in cnf["virtual_organization"].items():
self.vorg[key] = VirtualOrg(None, key, val)
continue
try:
setattr(self, arg, cnf[arg])
except KeyError:

View File

@@ -269,7 +269,7 @@ class AuthnResponse(StatusResponse):
elif self.allow_unsolicited:
pass
else:
logger("Unsolicited response")
logger.exception("Unsolicited response")
raise Exception("Unsolicited response")
return self

View File

@@ -110,22 +110,18 @@ class Identifier(object):
return temp_id
def _get_vo_identifier(self, sp_name_qualifier, userid, identity):
def _get_vo_identifier(self, sp_name_qualifier, identity):
try:
vo_conf = self.voconf[sp_name_qualifier]
if "common_identifier" in vo_conf:
vo = self.voconf[sp_name_qualifier]
try:
subj_id = identity[vo_conf["common_identifier"]]
subj_id = identity[vo.common_identifier]
except KeyError:
raise MissingValue("Common identifier")
else:
return self.persistent_nameid(sp_name_qualifier, userid)
except (KeyError, TypeError):
raise UnknownVO("%s" % sp_name_qualifier)
try:
nameid_format = vo_conf["nameid_format"]
except KeyError:
nameid_format = vo.nameid_format
if not nameid_format:
nameid_format = saml.NAMEID_FORMAT_PERSISTENT
return saml.NameID(format=nameid_format,
@@ -189,9 +185,9 @@ class Identifier(object):
if name_id_policy and name_id_policy.sp_name_qualifier:
try:
return self._get_vo_identifier(name_id_policy.sp_name_qualifier,
userid, identity)
except Exception:
pass
identity)
except Exception, exc:
print >> sys.stderr, "%s:%s" % (exc.__class__.__name__, exc)
if sp_nid:
nameid_format = sp_nid[0]

View File

@@ -1,14 +1,22 @@
import logging
from saml2.attribute_resolver import AttributeResolver
from saml2.saml import NAMEID_FORMAT_PERSISTENT
logger = logging.getLogger(__name__)
class VirtualOrg(object):
def __init__(self, sp, vorg):
def __init__(self, sp, vorg, cnf):
self.sp = sp # The parent SP client instance
self.config = sp.config
self.vorg_name = vorg
self.vorg_conf = self.config.vo_conf(self.vorg_name)
self._name = vorg
self.common_identifier = cnf["common_identifier"]
try:
self.member = cnf["member"]
except KeyError:
self.member = []
try:
self.nameid_format = cnf["nameid_format"]
except KeyError:
self.nameid_format = NAMEID_FORMAT_PERSISTENT
def _cache_session(self, session_info):
return True
@@ -18,17 +26,7 @@ class VirtualOrg(object):
Get the member of the Virtual Organization from the metadata,
more specifically from AffiliationDescriptor.
"""
return self.config.metadata.vo_members(self.vorg_name)
def _vo_conf_members(self):
"""
Get the member of the Virtual Organization from the configuration.
"""
try:
return self.vorg_conf["member"]
except (KeyError, TypeError):
return []
return self.sp.config.metadata.vo_members(self._name)
def members_to_ask(self, subject_id):
"""Find the member of the Virtual Organization that I haven't already
@@ -36,7 +34,7 @@ class VirtualOrg(object):
"""
vo_members = self._affiliation_members()
for member in self._vo_conf_members():
for member in self.member:
if member not in vo_members:
vo_members.append(member)
@@ -51,7 +49,7 @@ class VirtualOrg(object):
if ava == {}:
return None
ident = self.vorg_conf["common_identifier"]
ident = self.common_identifier
try:
return ava[ident][0]
@@ -61,16 +59,16 @@ class VirtualOrg(object):
def do_aggregation(self, subject_id):
logger.info("** Do VO aggregation **\nSubjectID: %s, VO:%s" % (
subject_id, self.vorg_name))
subject_id, self._name))
to_ask = self.members_to_ask(subject_id)
if to_ask:
# Find the NameIDFormat and the SPNameQualifier
if self.vorg_conf and "nameid_format" in self.vorg_conf:
name_id_format = self.vorg_conf["nameid_format"]
if self.nameid_format:
name_id_format = self.nameid_format
sp_name_qualifier = ""
else:
sp_name_qualifier = self.vorg_name
sp_name_qualifier = self._name
name_id_format = ""
com_identifier = self.get_common_identifier(subject_id)

View File

@@ -192,7 +192,7 @@ def test_2():
assert len(c._sp_idp) == 1
assert c._sp_idp.keys() == [""]
assert c._sp_idp.values() == ["https://example.com/saml2/idp/SSOService.php"]
assert c.only_use_keys_in_metadata is None
assert c.only_use_keys_in_metadata is True
def test_minimum():
minimum = {

View File

@@ -6,7 +6,6 @@ from saml2.config import IdPConfig
from saml2.server import Identifier
from saml2.assertion import Policy
def _eq(l1,l2):
return set(l1) == set(l2)
@@ -34,6 +33,8 @@ CONFIG = IdPConfig().load({
"common_identifier": "uid",
},
"http://vo.example.org/design":{
"nameid_format" : NAMEID_FORMAT_PERSISTENT,
"common_identifier": "uid",
}
}
})
@@ -53,7 +54,7 @@ NAME_ID_POLICY_2 = """<?xml version="1.0" encoding="utf-8"?>
class TestIdentifier():
def setup_class(self):
self.id = Identifier("subject.db", CONFIG.virtual_organization)
self.id = Identifier("subject.db", CONFIG.vorg)
def test_persistent_1(self):
policy = Policy({
@@ -109,16 +110,17 @@ class TestIdentifier():
})
name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_1)
print name_id_policy
print self.id.voconf
nameid = self.id.construct_nameid(policy, "foobar",
"urn:mace:example.com:sp:1",
{"uid": "foobar01"},
name_id_policy)
print nameid
assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format'])
assert nameid.sp_name_qualifier == 'http://vo.example.org/biomed'
assert nameid.format == \
CONFIG.virtual_organization['http://vo.example.org/biomed'][
"nameid_format"]
assert nameid.format == 'urn:oid:2.16.756.1.2.5.1.1.1-NameID'
assert nameid.text == "foobar01"
def test_vo_2(self):
@@ -142,5 +144,5 @@ class TestIdentifier():
assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format'])
assert nameid.sp_name_qualifier == 'http://vo.example.org/design'
assert nameid.format == NAMEID_FORMAT_PERSISTENT
assert nameid.text != "foobar01"
assert nameid.text == "foobar01"

View File

@@ -45,6 +45,7 @@ class TestAuthnResponse:
policy=policy)
self.conf = config_factory("sp", "server_conf")
self.conf.only_use_keys_in_metadata = False
self.ar = authn_response(self.conf, "http://lingon.catalogix.se:8087/")
def test_verify_1(self):

View File

@@ -201,7 +201,7 @@ class TestClient:
assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT
def test_create_auth_request_vo(self):
assert self.client.config.virtual_organization.keys() == [
assert self.client.config.vorg.keys() == [
"urn:mace:example.com:it:tek"]
ar_str = "%s" % self.client.create_authn_request(
@@ -337,7 +337,7 @@ class TestClient:
def test_authenticate(self):
print self.client.config.idps()
response = self.client.do_authenticate(
id, response = self.client.do_authenticate(
"urn:mace:example.com:saml:roland:idp",
"http://www.example.com/relay_state")
assert response[0] == "Location"
@@ -349,7 +349,7 @@ class TestClient:
authnreq = samlp.authn_request_from_string(saml_request)
def test_authenticate_no_args(self):
response = self.client.do_authenticate(relay_state="http://www.example.com/relay_state")
id, response = self.client.do_authenticate(relay_state="http://www.example.com/relay_state")
assert response[0] == "Location"
o = urlparse(response[1])
qdict = parse_qs(o.query)

View File

@@ -24,8 +24,8 @@ class TestVirtualOrg():
conf.load_file("server_conf")
self.sp = Saml2Client(conf)
vo_name = conf.virtual_organization.keys()[0]
self.vo = VirtualOrg(self.sp, vo_name)
vo_name = conf.vorg.keys()[0]
self.vo = conf.vorg[vo_name]
add_derek_info(self.sp)
def test_mta(self):
@@ -53,7 +53,7 @@ class TestVirtualOrg_2():
def setup_class(self):
conf = config.SPConfig()
conf.load_file("server_conf")
vo_name = conf.virtual_organization.keys()[0]
vo_name = conf.vorg.keys()[0]
self.sp = Saml2Client(conf, virtual_organization=vo_name)
add_derek_info(self.sp)