Change the config module, resulted in changes in many files
This commit is contained in:
@@ -34,11 +34,8 @@ from repoze.who.interfaces import IChallenger, IIdentifier, IAuthenticator
|
|||||||
from repoze.who.interfaces import IMetadataProvider
|
from repoze.who.interfaces import IMetadataProvider
|
||||||
from repoze.who.plugins.form import FormPluginBase
|
from repoze.who.plugins.form import FormPluginBase
|
||||||
|
|
||||||
from saml2 import BINDING_HTTP_REDIRECT
|
|
||||||
from saml2.client import Saml2Client
|
from saml2.client import Saml2Client
|
||||||
from saml2.config import SPConfig
|
|
||||||
from saml2.s_utils import sid
|
from saml2.s_utils import sid
|
||||||
from saml2.virtual_org import VirtualOrg
|
|
||||||
|
|
||||||
#from saml2.population import Population
|
#from saml2.population import Population
|
||||||
#from saml2.attribute_resolver import AttributeResolver
|
#from saml2.attribute_resolver import AttributeResolver
|
||||||
@@ -417,13 +414,10 @@ def make_plugin(rememberer_name=None, # plugin for remember
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
'must include rememberer_name in configuration')
|
'must include rememberer_name in configuration')
|
||||||
|
|
||||||
config = SPConfig()
|
scl = Saml2Client(config_file=saml_conf, identity_cache=identity_cache,
|
||||||
config.load_file(saml_conf)
|
|
||||||
|
|
||||||
scl = Saml2Client(config, identity_cache=identity_cache,
|
|
||||||
virtual_organization=virtual_organization)
|
virtual_organization=virtual_organization)
|
||||||
|
|
||||||
plugin = SAML2Plugin(rememberer_name, config, scl, wayf, cache, debug,
|
plugin = SAML2Plugin(rememberer_name, scl.config, scl, wayf, cache, debug,
|
||||||
sid_store)
|
sid_store)
|
||||||
return plugin
|
return plugin
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from saml2.binding import send_using_soap, http_redirect_message
|
|||||||
from saml2.binding import http_post_message
|
from saml2.binding import http_post_message
|
||||||
from saml2.population import Population
|
from saml2.population import Population
|
||||||
from saml2.virtual_org import VirtualOrg
|
from saml2.virtual_org import VirtualOrg
|
||||||
from saml2.config import SPConfig
|
from saml2.config import config_factory
|
||||||
|
|
||||||
#from saml2.response import authn_response
|
#from saml2.response import authn_response
|
||||||
from saml2.response import response_factory
|
from saml2.response import response_factory
|
||||||
@@ -72,9 +72,9 @@ class LogoutError(Exception):
|
|||||||
class Saml2Client(object):
|
class Saml2Client(object):
|
||||||
""" The basic pySAML2 service provider class """
|
""" The basic pySAML2 service provider class """
|
||||||
|
|
||||||
def __init__(self, config=None, debug=0,
|
def __init__(self, config=None, debug=0,
|
||||||
identity_cache=None, state_cache=None,
|
identity_cache=None, state_cache=None,
|
||||||
virtual_organization=None):
|
virtual_organization=None, config_file=""):
|
||||||
"""
|
"""
|
||||||
:param config: A saml2.config.Config instance
|
:param config: A saml2.config.Config instance
|
||||||
:param debug: Whether debugging should be done even if the
|
:param debug: Whether debugging should be done even if the
|
||||||
@@ -96,11 +96,14 @@ class Saml2Client(object):
|
|||||||
self.sec = None
|
self.sec = None
|
||||||
if config:
|
if config:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.metadata = config.metadata
|
elif config_file:
|
||||||
self.sec = security_context(config)
|
self.config = config_factory("sp", config_file)
|
||||||
else:
|
else:
|
||||||
self.config = SPConfig()
|
raise Exception("Missing configuration")
|
||||||
|
|
||||||
|
self.metadata = self.config.metadata
|
||||||
|
self.sec = security_context(config)
|
||||||
|
|
||||||
if virtual_organization:
|
if virtual_organization:
|
||||||
self.vorg = VirtualOrg(self, virtual_organization)
|
self.vorg = VirtualOrg(self, virtual_organization)
|
||||||
else:
|
else:
|
||||||
@@ -464,12 +467,15 @@ class Saml2Client(object):
|
|||||||
log.info("No response")
|
log.info("No response")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def construct_logout_request(self, subject_id, destination, entity_id,
|
def construct_logout_request(self, subject_id, destination,
|
||||||
|
issuer_entity_id,
|
||||||
reason=None, expire=None):
|
reason=None, expire=None):
|
||||||
""" Constructs a LogoutRequest
|
""" Constructs a LogoutRequest
|
||||||
|
|
||||||
:param subject_id: The identifier of the subject
|
:param subject_id: The identifier of the subject
|
||||||
:param destination:
|
:param destination:
|
||||||
|
:param issuer_entity_id: The entity ID of the IdP the request is
|
||||||
|
target at.
|
||||||
:param reason: An indication of the reason for the logout, in the
|
:param reason: An indication of the reason for the logout, in the
|
||||||
form of a URI reference.
|
form of a URI reference.
|
||||||
:param expire: The time at which the request expires,
|
:param expire: The time at which the request expires,
|
||||||
@@ -480,7 +486,8 @@ class Saml2Client(object):
|
|||||||
session_id = sid()
|
session_id = sid()
|
||||||
# create NameID from subject_id
|
# create NameID from subject_id
|
||||||
name_id = saml.NameID(
|
name_id = saml.NameID(
|
||||||
text = self.users.get_entityid(subject_id, entity_id, False))
|
text = self.users.get_entityid(subject_id, issuer_entity_id,
|
||||||
|
False))
|
||||||
|
|
||||||
request = samlp.LogoutRequest(
|
request = samlp.LogoutRequest(
|
||||||
id=session_id,
|
id=session_id,
|
||||||
|
|||||||
@@ -2,65 +2,135 @@
|
|||||||
|
|
||||||
__author__ = 'rolandh'
|
__author__ = 'rolandh'
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from importlib import import_module
|
||||||
from saml2 import BINDING_SOAP, BINDING_HTTP_REDIRECT
|
from saml2 import BINDING_SOAP, BINDING_HTTP_REDIRECT
|
||||||
from saml2 import metadata
|
from saml2 import metadata
|
||||||
from saml2.attribute_converter import ac_factory
|
from saml2.attribute_converter import ac_factory
|
||||||
from saml2.assertion import Policy
|
from saml2.assertion import Policy
|
||||||
|
|
||||||
SIMPLE_ARGS = ["entityid", "xmlsec_binary", "debug", "key_file", "cert_file",
|
COMMON_ARGS = ["entityid", "xmlsec_binary", "debug", "key_file", "cert_file",
|
||||||
"secret", "accepted_time_diff", "virtual_organization", "name",
|
"secret", "accepted_time_diff", "name",
|
||||||
"description", "endpoints", "required_attributes",
|
"description",
|
||||||
"optional_attributes", "idp", "sp", "aa", "subject_data",
|
"organization",
|
||||||
"want_assertions_signed", "authn_requests_signed", "type",
|
"contact_person",
|
||||||
"organization", "contact_person",
|
"name_form",
|
||||||
"want_authn_requests_signed", "name_form"]
|
"virtual_organization",
|
||||||
|
]
|
||||||
|
|
||||||
COMPLEX_ARGS = ["metadata", "attribute_converters", "policy"]
|
SP_ARGS = [
|
||||||
|
"required_attributes",
|
||||||
|
"optional_attributes",
|
||||||
|
"idp",
|
||||||
|
"subject_data",
|
||||||
|
"want_assertions_signed",
|
||||||
|
"authn_requests_signed",
|
||||||
|
"name_form",
|
||||||
|
"endpoints",
|
||||||
|
]
|
||||||
|
|
||||||
|
AA_IDP_ARGS = ["want_authn_requests_signed",
|
||||||
|
"provided_attributes",
|
||||||
|
"subject_data",
|
||||||
|
"sp",
|
||||||
|
"endpoints",
|
||||||
|
"metadata"]
|
||||||
|
|
||||||
|
COMPLEX_ARGS = ["attribute_converters", "metadata", "policy"]
|
||||||
|
ALL = COMMON_ARGS + SP_ARGS + AA_IDP_ARGS + COMPLEX_ARGS
|
||||||
|
|
||||||
|
|
||||||
|
SPEC = {
|
||||||
|
"": COMMON_ARGS + COMPLEX_ARGS,
|
||||||
|
"sp": COMMON_ARGS + COMPLEX_ARGS + SP_ARGS,
|
||||||
|
"idp": COMMON_ARGS + COMPLEX_ARGS + AA_IDP_ARGS,
|
||||||
|
"aa": COMMON_ARGS + COMPLEX_ARGS + AA_IDP_ARGS,
|
||||||
|
}
|
||||||
|
|
||||||
class Config(object):
|
class Config(object):
|
||||||
def __init__(self):
|
def_context = ""
|
||||||
self._attr = {}
|
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._attr = {"": {}, "sp": {}, "idp": {}, "aa": {}}
|
||||||
|
self.context = ""
|
||||||
|
|
||||||
|
def serves(self):
|
||||||
|
return [t for t in ["sp", "idp", "aa"] if self._attr[t]]
|
||||||
|
|
||||||
def __getattribute__(self, item):
|
def __getattribute__(self, item):
|
||||||
if item in SIMPLE_ARGS or item in COMPLEX_ARGS:
|
if item == "context":
|
||||||
|
return object.__getattribute__(self, item)
|
||||||
|
|
||||||
|
_context = self.context
|
||||||
|
if item in ALL:
|
||||||
try:
|
try:
|
||||||
return self._attr[item]
|
return self._attr[_context][item]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
if _context:
|
||||||
|
try:
|
||||||
|
return self._attr[""][item]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return object.__getattribute__(self, item)
|
return object.__getattribute__(self, item)
|
||||||
|
|
||||||
def load(self, cnf):
|
def load_special(self, cnf, typ):
|
||||||
|
for arg in SPEC[typ]:
|
||||||
for arg in SIMPLE_ARGS:
|
|
||||||
try:
|
try:
|
||||||
self._attr[arg] = cnf[arg]
|
self._attr[typ][arg] = cnf[arg]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
self.context = typ
|
||||||
|
self.load_complex(cnf, typ)
|
||||||
|
self.context = self.def_context
|
||||||
|
|
||||||
|
def load_complex(self, cnf, typ=""):
|
||||||
try:
|
try:
|
||||||
self._attr["policy"] = Policy(cnf["policy"])
|
self._attr[typ]["policy"] = Policy(cnf["policy"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
acs = ac_factory(cnf["attribute_map_dir"])
|
acs = ac_factory(cnf["attribute_map_dir"])
|
||||||
try:
|
try:
|
||||||
self._attr["attribute_converters"].extend(acs)
|
self._attr[typ]["attribute_converters"].extend(acs)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self._attr["attribute_converters"] = acs
|
self._attr[typ]["attribute_converters"] = acs
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._attr["metadata"] = self.load_metadata(cnf["metadata"])
|
self._attr[typ]["metadata"] = self.load_metadata(cnf["metadata"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def load(self, cnf):
|
||||||
|
|
||||||
|
for arg in COMMON_ARGS:
|
||||||
|
try:
|
||||||
|
self._attr[""][arg] = cnf[arg]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if "service" in cnf:
|
||||||
|
for typ in ["aa", "idp", "sp"]:
|
||||||
|
try:
|
||||||
|
self.load_special(cnf["service"][typ], typ)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.load_complex(cnf)
|
||||||
|
self.context = self.def_context
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def load_file(self, config_file):
|
def load_file(self, config_file):
|
||||||
return self.load(eval(open(config_file).read()))
|
if sys.path[0] != ".":
|
||||||
|
sys.path.insert(0, ".")
|
||||||
|
mod = import_module(config_file)
|
||||||
|
#return self.load(eval(open(config_file).read()))
|
||||||
|
return self.load(mod.CONFIG)
|
||||||
|
|
||||||
def load_metadata(self, metadata_conf):
|
def load_metadata(self, metadata_conf):
|
||||||
""" Loads metadata into an internal structure """
|
""" Loads metadata into an internal structure """
|
||||||
@@ -110,6 +180,8 @@ class Config(object):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
class SPConfig(Config):
|
class SPConfig(Config):
|
||||||
|
def_context = "sp"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Config.__init__(self)
|
Config.__init__(self)
|
||||||
|
|
||||||
@@ -167,6 +239,8 @@ class SPConfig(Config):
|
|||||||
return self.metadata.idps()
|
return self.metadata.idps()
|
||||||
|
|
||||||
class IdPConfig(Config):
|
class IdPConfig(Config):
|
||||||
|
def_context = "idp"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Config.__init__(self)
|
Config.__init__(self)
|
||||||
|
|
||||||
@@ -190,3 +264,15 @@ class IdPConfig(Config):
|
|||||||
return [s[binding] for s in acs]
|
return [s[binding] for s in acs]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def config_factory(typ, file):
|
||||||
|
if typ == "sp":
|
||||||
|
conf = SPConfig().load_file(file)
|
||||||
|
conf.context = typ
|
||||||
|
elif typ in ["aa", "idp"]:
|
||||||
|
conf = IdPConfig().load_file(file)
|
||||||
|
conf.context = typ
|
||||||
|
else:
|
||||||
|
conf = Config().load_file(file)
|
||||||
|
conf.context = typ
|
||||||
|
return conf
|
||||||
|
|||||||
@@ -719,6 +719,7 @@ def do_requested_attribute(attributes, acs, is_required="false"):
|
|||||||
for key in attr.keyswv():
|
for key in attr.keyswv():
|
||||||
args[key] = getattr(attr, key)
|
args[key] = getattr(attr, key)
|
||||||
args["is_required"] = is_required
|
args["is_required"] = is_required
|
||||||
|
args["name_format"] = NAME_FORMAT_URI
|
||||||
lista.append(md.RequestedAttribute(**args))
|
lista.append(md.RequestedAttribute(**args))
|
||||||
return lista
|
return lista
|
||||||
|
|
||||||
@@ -917,16 +918,18 @@ def entity_descriptor(confd, valid_for):
|
|||||||
if confd.contact_person is not None:
|
if confd.contact_person is not None:
|
||||||
entd.contact_person = do_contact_person_info(confd.contact_person)
|
entd.contact_person = do_contact_person_info(confd.contact_person)
|
||||||
|
|
||||||
if confd.type == "sp":
|
serves = confd.serves()
|
||||||
entd.spsso_descriptor = do_sp_sso_descriptor(confd, mycert)
|
if not serves:
|
||||||
elif confd.type == "idp":
|
|
||||||
entd.idpsso_descriptor = do_idp_sso_descriptor(confd, mycert)
|
|
||||||
elif confd.type == "aa":
|
|
||||||
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'No service type ("sp","idp","aa") provided in the configuration')
|
'No service type ("sp","idp","aa") provided in the configuration')
|
||||||
|
|
||||||
|
if "sp" in serves:
|
||||||
|
entd.spsso_descriptor = do_sp_sso_descriptor(confd, mycert)
|
||||||
|
if "idp" in serves:
|
||||||
|
entd.idpsso_descriptor = do_idp_sso_descriptor(confd, mycert)
|
||||||
|
if "aa" in serves:
|
||||||
|
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert)
|
||||||
|
|
||||||
return entd
|
return entd
|
||||||
|
|
||||||
def entities_descriptor(eds, valid_for, name, ident, sign, secc):
|
def entities_descriptor(eds, valid_for, name, ident, sign, secc):
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def for_me(condition, myself ):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def authn_response(conf, entity_id, return_addr, outstanding_queries=None,
|
def authn_response(conf, return_addr, outstanding_queries=None,
|
||||||
log=None, timeslack=0, debug=0):
|
log=None, timeslack=0, debug=0):
|
||||||
sec = security_context(conf)
|
sec = security_context(conf)
|
||||||
if not timeslack:
|
if not timeslack:
|
||||||
@@ -65,12 +65,12 @@ def authn_response(conf, entity_id, return_addr, outstanding_queries=None,
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
timeslack = 0
|
timeslack = 0
|
||||||
|
|
||||||
return AuthnResponse(sec, conf.attribute_converters, entity_id,
|
return AuthnResponse(sec, conf.attribute_converters, conf.entityid,
|
||||||
return_addr, outstanding_queries, log, timeslack,
|
return_addr, outstanding_queries, log, timeslack,
|
||||||
debug)
|
debug)
|
||||||
|
|
||||||
# comes in over SOAP so synchronous
|
# comes in over SOAP so synchronous
|
||||||
def attribute_response(conf, entity_id, return_addr, log=None, timeslack=0,
|
def attribute_response(conf, return_addr, log=None, timeslack=0,
|
||||||
debug=0):
|
debug=0):
|
||||||
sec = security_context(conf)
|
sec = security_context(conf)
|
||||||
if not timeslack:
|
if not timeslack:
|
||||||
@@ -79,7 +79,7 @@ def attribute_response(conf, entity_id, return_addr, log=None, timeslack=0,
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
timeslack = 0
|
timeslack = 0
|
||||||
|
|
||||||
return AttributeResponse(sec, conf.attribute_converters, entity_id,
|
return AttributeResponse(sec, conf.attribute_converters, conf.entityid,
|
||||||
return_addr, log, timeslack, debug)
|
return_addr, log, timeslack, debug)
|
||||||
|
|
||||||
class StatusResponse(object):
|
class StatusResponse(object):
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ from saml2.binding import http_post_message
|
|||||||
from saml2.sigver import security_context
|
from saml2.sigver import security_context
|
||||||
from saml2.sigver import signed_instance_factory
|
from saml2.sigver import signed_instance_factory
|
||||||
from saml2.sigver import pre_signature_part
|
from saml2.sigver import pre_signature_part
|
||||||
from saml2.config import IdPConfig
|
from saml2.config import config_factory
|
||||||
from saml2.assertion import Assertion, Policy
|
from saml2.assertion import Assertion, Policy
|
||||||
|
|
||||||
class UnknownVO(Exception):
|
class UnknownVO(Exception):
|
||||||
@@ -212,6 +212,8 @@ class Server(object):
|
|||||||
self.load_config(config_file)
|
self.load_config(config_file)
|
||||||
elif config:
|
elif config:
|
||||||
self.conf = config
|
self.conf = config
|
||||||
|
else:
|
||||||
|
raise Exception("Missing configuration")
|
||||||
|
|
||||||
self.metadata = self.conf.metadata
|
self.metadata = self.conf.metadata
|
||||||
self.sec = security_context(self.conf, log)
|
self.sec = security_context(self.conf, log)
|
||||||
@@ -228,8 +230,7 @@ class Server(object):
|
|||||||
|
|
||||||
:param config_file: The name of the configuration file
|
:param config_file: The name of the configuration file
|
||||||
"""
|
"""
|
||||||
self.conf = IdPConfig()
|
self.conf = config_factory("idp", config_file)
|
||||||
self.conf.load_file(config_file)
|
|
||||||
try:
|
try:
|
||||||
# subject information is store in database
|
# subject information is store in database
|
||||||
# default database is a shelve database which is OK in some setups
|
# default database is a shelve database which is OK in some setups
|
||||||
@@ -593,7 +594,10 @@ class Server(object):
|
|||||||
self.log.info("enpoints: %s" % (self.conf.endpoints,))
|
self.log.info("enpoints: %s" % (self.conf.endpoints,))
|
||||||
self.log.info("binding wanted: %s" % (binding,))
|
self.log.info("binding wanted: %s" % (binding,))
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
if not slo:
|
||||||
|
raise Exception("No single_logout_server for that binding")
|
||||||
|
|
||||||
if self.log:
|
if self.log:
|
||||||
self.log.info("Endpoint: %s" % slo)
|
self.log.info("Endpoint: %s" % slo)
|
||||||
req = LogoutRequest(self.sec, slo)
|
req = LogoutRequest(self.sec, slo)
|
||||||
|
|||||||
Reference in New Issue
Block a user