Refactored the support for metadata extension in the config file.

This commit is contained in:
Roland Hedberg
2014-03-13 15:15:20 +01:00
parent d7f50d4b05
commit 5cd5ebdc03
3 changed files with 87 additions and 15 deletions

View File

@@ -23,6 +23,7 @@ import xmldsig as ds
from saml2.sigver import pre_signature_part
from saml2.s_utils import factory
from saml2.s_utils import rec_factory
from saml2.s_utils import sid
__author__ = 'rolandh'
@@ -51,6 +52,7 @@ ORG_ATTR_TRANSL = {
"organization_url": ("url", md.OrganizationURL)
}
def metadata_tostring_fix(desc, nspair):
MDNS = '"urn:oasis:names:tc:SAML:2.0:metadata"'
XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
@@ -60,7 +62,7 @@ def metadata_tostring_fix(desc, nspair):
return xmlstring
def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
def create_metadata_string(configfile, config, valid, cert, keyfile, mid, name,
sign):
valid_for = 0
nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
@@ -85,8 +87,8 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
conf.xmlsec_binary = config.xmlsec_binary
secc = security_context(conf)
if id:
desc = entities_descriptor(eds, valid_for, name, id,
if mid:
desc = entities_descriptor(eds, valid_for, name, mid,
sign, secc)
valid_instance(desc)
@@ -94,7 +96,7 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
else:
for eid in eds:
if sign:
desc = sign_entity_descriptor(eid, id, secc)
desc = sign_entity_descriptor(eid, mid, secc)
else:
desc = eid
valid_instance(desc)
@@ -372,6 +374,21 @@ DEFAULT_BINDING = {
}
def do_extensions(mname, item):
try:
_mod = __import__("saml2.extension.%s" % mname, globals(), locals(),
mname)
except ImportError:
return None
else:
res = []
for _cname, ava in item.items():
cls = getattr(_mod, _cname)
res.append(rec_factory(cls, **ava))
return res
def _do_nameid_format(cls, conf, typ):
namef = conf.getattr("name_id_format", typ)
if namef:
@@ -421,19 +438,30 @@ def do_spsso_descriptor(conf, cert=None):
spsso = md.SPSSODescriptor()
spsso.protocol_support_enumeration = samlp.NAMESPACE
exts = conf.getattr("extensions", "sp")
if exts:
if spsso.extensions is None:
spsso.extensions = md.Extensions()
for key, val in exts.items():
_ext = do_extensions(key, val)
if _ext:
for _e in _ext:
spsso.extensions.add_extension_element(_e)
endps = conf.getattr("endpoints", "sp")
if endps:
for (endpoint, instlist) in do_endpoints(endps,
ENDPOINTS["sp"]).items():
setattr(spsso, endpoint, instlist)
ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
if ext:
if spsso.extensions is None:
spsso.extensions = md.Extensions()
for vals in ext.values():
for val in vals:
spsso.extensions.add_extension_element(val)
# ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
# if ext:
# if spsso.extensions is None:
# spsso.extensions = md.Extensions()
# for vals in ext.values():
# for val in vals:
# spsso.extensions.add_extension_element(val)
if cert:
encryption_type = conf.encryption_type

View File

@@ -413,7 +413,7 @@ def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
"PN": csum.hexdigest(),
"AM": ac.AuthnContextClassRef.text
}
logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a,v) for a,v in info]))
logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a, v) for a, v in info]))
def dynamic_importer(name, class_name=None):
@@ -428,14 +428,14 @@ def dynamic_importer(name, class_name=None):
try:
package = imp.load_module(name, fp, pathname, description)
except Exception, e:
except Exception:
raise
if class_name:
try:
_class = imp.load_module("%s.%s" % (name, class_name), fp,
pathname, description)
except Exception, e:
pathname, description)
except Exception:
raise
return package, _class
@@ -452,3 +452,34 @@ def exception_trace(exc):
_exc = "Exception: %s" % exc.message.encode("utf-8", "replace")
return {"message": _exc, "content": "".join(message)}
def rec_factory(cls, **kwargs):
_inst = cls()
for key, val in kwargs.items():
if key in ["text", "lang"]:
setattr(_inst, key, val)
elif key in _inst.c_attributes:
try:
val = str(val)
except Exception:
continue
else:
setattr(_inst, key, val)
elif key in _inst.c_child_order:
for tag, _cls in _inst.c_children.values():
if tag == key:
if isinstance(_cls, list):
_cls = _cls[0]
claim = []
if isinstance(val, list):
for v in val:
claim.append(rec_factory(_cls, **v))
else:
claim.append(rec_factory(_cls, **val))
else:
claim = rec_factory(_cls, **val)
setattr(_inst, key, claim)
break
return _inst

View File

@@ -201,6 +201,18 @@ def valid_unsigned_short(val):
return True
def valid_positive_integer(val):
try:
integer = int(val)
except ValueError:
raise NotValid("positive integer")
if integer > 0:
return True
else:
raise NotValid("positive integer")
def valid_non_negative_integer(val):
try:
integer = int(val)
@@ -269,6 +281,7 @@ VALIDATOR = {
"dateTime": valid_date_time,
"anyURI": valid_any_uri,
"nonNegativeInteger": valid_non_negative_integer,
"PositiveInteger": valid_positive_integer,
"boolean": valid_boolean,
"unsignedShort": valid_unsigned_short,
"duration": valid_duration,