Refactored the support for metadata extension in the config file.
This commit is contained in:
@@ -23,6 +23,7 @@ import xmldsig as ds
|
|||||||
from saml2.sigver import pre_signature_part
|
from saml2.sigver import pre_signature_part
|
||||||
|
|
||||||
from saml2.s_utils import factory
|
from saml2.s_utils import factory
|
||||||
|
from saml2.s_utils import rec_factory
|
||||||
from saml2.s_utils import sid
|
from saml2.s_utils import sid
|
||||||
|
|
||||||
__author__ = 'rolandh'
|
__author__ = 'rolandh'
|
||||||
@@ -51,6 +52,7 @@ ORG_ATTR_TRANSL = {
|
|||||||
"organization_url": ("url", md.OrganizationURL)
|
"organization_url": ("url", md.OrganizationURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def metadata_tostring_fix(desc, nspair):
|
def metadata_tostring_fix(desc, nspair):
|
||||||
MDNS = '"urn:oasis:names:tc:SAML:2.0:metadata"'
|
MDNS = '"urn:oasis:names:tc:SAML:2.0:metadata"'
|
||||||
XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
|
XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
|
||||||
@@ -60,7 +62,7 @@ def metadata_tostring_fix(desc, nspair):
|
|||||||
return xmlstring
|
return xmlstring
|
||||||
|
|
||||||
|
|
||||||
def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
|
def create_metadata_string(configfile, config, valid, cert, keyfile, mid, name,
|
||||||
sign):
|
sign):
|
||||||
valid_for = 0
|
valid_for = 0
|
||||||
nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
|
nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
|
||||||
@@ -85,8 +87,8 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
|
|||||||
conf.xmlsec_binary = config.xmlsec_binary
|
conf.xmlsec_binary = config.xmlsec_binary
|
||||||
secc = security_context(conf)
|
secc = security_context(conf)
|
||||||
|
|
||||||
if id:
|
if mid:
|
||||||
desc = entities_descriptor(eds, valid_for, name, id,
|
desc = entities_descriptor(eds, valid_for, name, mid,
|
||||||
sign, secc)
|
sign, secc)
|
||||||
valid_instance(desc)
|
valid_instance(desc)
|
||||||
|
|
||||||
@@ -94,7 +96,7 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
|
|||||||
else:
|
else:
|
||||||
for eid in eds:
|
for eid in eds:
|
||||||
if sign:
|
if sign:
|
||||||
desc = sign_entity_descriptor(eid, id, secc)
|
desc = sign_entity_descriptor(eid, mid, secc)
|
||||||
else:
|
else:
|
||||||
desc = eid
|
desc = eid
|
||||||
valid_instance(desc)
|
valid_instance(desc)
|
||||||
@@ -372,6 +374,21 @@ DEFAULT_BINDING = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def do_extensions(mname, item):
|
||||||
|
try:
|
||||||
|
_mod = __import__("saml2.extension.%s" % mname, globals(), locals(),
|
||||||
|
mname)
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for _cname, ava in item.items():
|
||||||
|
cls = getattr(_mod, _cname)
|
||||||
|
res.append(rec_factory(cls, **ava))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def _do_nameid_format(cls, conf, typ):
|
def _do_nameid_format(cls, conf, typ):
|
||||||
namef = conf.getattr("name_id_format", typ)
|
namef = conf.getattr("name_id_format", typ)
|
||||||
if namef:
|
if namef:
|
||||||
@@ -421,19 +438,30 @@ def do_spsso_descriptor(conf, cert=None):
|
|||||||
spsso = md.SPSSODescriptor()
|
spsso = md.SPSSODescriptor()
|
||||||
spsso.protocol_support_enumeration = samlp.NAMESPACE
|
spsso.protocol_support_enumeration = samlp.NAMESPACE
|
||||||
|
|
||||||
|
exts = conf.getattr("extensions", "sp")
|
||||||
|
if exts:
|
||||||
|
if spsso.extensions is None:
|
||||||
|
spsso.extensions = md.Extensions()
|
||||||
|
|
||||||
|
for key, val in exts.items():
|
||||||
|
_ext = do_extensions(key, val)
|
||||||
|
if _ext:
|
||||||
|
for _e in _ext:
|
||||||
|
spsso.extensions.add_extension_element(_e)
|
||||||
|
|
||||||
endps = conf.getattr("endpoints", "sp")
|
endps = conf.getattr("endpoints", "sp")
|
||||||
if endps:
|
if endps:
|
||||||
for (endpoint, instlist) in do_endpoints(endps,
|
for (endpoint, instlist) in do_endpoints(endps,
|
||||||
ENDPOINTS["sp"]).items():
|
ENDPOINTS["sp"]).items():
|
||||||
setattr(spsso, endpoint, instlist)
|
setattr(spsso, endpoint, instlist)
|
||||||
|
|
||||||
ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
|
# ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
|
||||||
if ext:
|
# if ext:
|
||||||
if spsso.extensions is None:
|
# if spsso.extensions is None:
|
||||||
spsso.extensions = md.Extensions()
|
# spsso.extensions = md.Extensions()
|
||||||
for vals in ext.values():
|
# for vals in ext.values():
|
||||||
for val in vals:
|
# for val in vals:
|
||||||
spsso.extensions.add_extension_element(val)
|
# spsso.extensions.add_extension_element(val)
|
||||||
|
|
||||||
if cert:
|
if cert:
|
||||||
encryption_type = conf.encryption_type
|
encryption_type = conf.encryption_type
|
||||||
|
|||||||
@@ -413,7 +413,7 @@ def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
|
|||||||
"PN": csum.hexdigest(),
|
"PN": csum.hexdigest(),
|
||||||
"AM": ac.AuthnContextClassRef.text
|
"AM": ac.AuthnContextClassRef.text
|
||||||
}
|
}
|
||||||
logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a,v) for a,v in info]))
|
logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a, v) for a, v in info]))
|
||||||
|
|
||||||
|
|
||||||
def dynamic_importer(name, class_name=None):
|
def dynamic_importer(name, class_name=None):
|
||||||
@@ -428,14 +428,14 @@ def dynamic_importer(name, class_name=None):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
package = imp.load_module(name, fp, pathname, description)
|
package = imp.load_module(name, fp, pathname, description)
|
||||||
except Exception, e:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if class_name:
|
if class_name:
|
||||||
try:
|
try:
|
||||||
_class = imp.load_module("%s.%s" % (name, class_name), fp,
|
_class = imp.load_module("%s.%s" % (name, class_name), fp,
|
||||||
pathname, description)
|
pathname, description)
|
||||||
except Exception, e:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return package, _class
|
return package, _class
|
||||||
@@ -452,3 +452,34 @@ def exception_trace(exc):
|
|||||||
_exc = "Exception: %s" % exc.message.encode("utf-8", "replace")
|
_exc = "Exception: %s" % exc.message.encode("utf-8", "replace")
|
||||||
|
|
||||||
return {"message": _exc, "content": "".join(message)}
|
return {"message": _exc, "content": "".join(message)}
|
||||||
|
|
||||||
|
|
||||||
|
def rec_factory(cls, **kwargs):
|
||||||
|
_inst = cls()
|
||||||
|
for key, val in kwargs.items():
|
||||||
|
if key in ["text", "lang"]:
|
||||||
|
setattr(_inst, key, val)
|
||||||
|
elif key in _inst.c_attributes:
|
||||||
|
try:
|
||||||
|
val = str(val)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
setattr(_inst, key, val)
|
||||||
|
elif key in _inst.c_child_order:
|
||||||
|
for tag, _cls in _inst.c_children.values():
|
||||||
|
if tag == key:
|
||||||
|
if isinstance(_cls, list):
|
||||||
|
_cls = _cls[0]
|
||||||
|
claim = []
|
||||||
|
if isinstance(val, list):
|
||||||
|
for v in val:
|
||||||
|
claim.append(rec_factory(_cls, **v))
|
||||||
|
else:
|
||||||
|
claim.append(rec_factory(_cls, **val))
|
||||||
|
else:
|
||||||
|
claim = rec_factory(_cls, **val)
|
||||||
|
setattr(_inst, key, claim)
|
||||||
|
break
|
||||||
|
|
||||||
|
return _inst
|
||||||
|
|||||||
@@ -201,6 +201,18 @@ def valid_unsigned_short(val):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def valid_positive_integer(val):
|
||||||
|
try:
|
||||||
|
integer = int(val)
|
||||||
|
except ValueError:
|
||||||
|
raise NotValid("positive integer")
|
||||||
|
|
||||||
|
if integer > 0:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise NotValid("positive integer")
|
||||||
|
|
||||||
|
|
||||||
def valid_non_negative_integer(val):
|
def valid_non_negative_integer(val):
|
||||||
try:
|
try:
|
||||||
integer = int(val)
|
integer = int(val)
|
||||||
@@ -269,6 +281,7 @@ VALIDATOR = {
|
|||||||
"dateTime": valid_date_time,
|
"dateTime": valid_date_time,
|
||||||
"anyURI": valid_any_uri,
|
"anyURI": valid_any_uri,
|
||||||
"nonNegativeInteger": valid_non_negative_integer,
|
"nonNegativeInteger": valid_non_negative_integer,
|
||||||
|
"PositiveInteger": valid_positive_integer,
|
||||||
"boolean": valid_boolean,
|
"boolean": valid_boolean,
|
||||||
"unsignedShort": valid_unsigned_short,
|
"unsignedShort": valid_unsigned_short,
|
||||||
"duration": valid_duration,
|
"duration": valid_duration,
|
||||||
|
|||||||
Reference in New Issue
Block a user