Refactored the support for metadata extension in the config file.

This commit is contained in:
Roland Hedberg
2014-03-13 15:15:20 +01:00
parent d7f50d4b05
commit 5cd5ebdc03
3 changed files with 87 additions and 15 deletions

View File

@@ -23,6 +23,7 @@ import xmldsig as ds
from saml2.sigver import pre_signature_part from saml2.sigver import pre_signature_part
from saml2.s_utils import factory from saml2.s_utils import factory
from saml2.s_utils import rec_factory
from saml2.s_utils import sid from saml2.s_utils import sid
__author__ = 'rolandh' __author__ = 'rolandh'
@@ -51,6 +52,7 @@ ORG_ATTR_TRANSL = {
"organization_url": ("url", md.OrganizationURL) "organization_url": ("url", md.OrganizationURL)
} }
def metadata_tostring_fix(desc, nspair): def metadata_tostring_fix(desc, nspair):
MDNS = '"urn:oasis:names:tc:SAML:2.0:metadata"' MDNS = '"urn:oasis:names:tc:SAML:2.0:metadata"'
XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\"" XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
@@ -60,7 +62,7 @@ def metadata_tostring_fix(desc, nspair):
return xmlstring return xmlstring
def create_metadata_string(configfile, config, valid, cert, keyfile, id, name, def create_metadata_string(configfile, config, valid, cert, keyfile, mid, name,
sign): sign):
valid_for = 0 valid_for = 0
nspair = {"xs": "http://www.w3.org/2001/XMLSchema"} nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
@@ -85,8 +87,8 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
conf.xmlsec_binary = config.xmlsec_binary conf.xmlsec_binary = config.xmlsec_binary
secc = security_context(conf) secc = security_context(conf)
if id: if mid:
desc = entities_descriptor(eds, valid_for, name, id, desc = entities_descriptor(eds, valid_for, name, mid,
sign, secc) sign, secc)
valid_instance(desc) valid_instance(desc)
@@ -94,7 +96,7 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
else: else:
for eid in eds: for eid in eds:
if sign: if sign:
desc = sign_entity_descriptor(eid, id, secc) desc = sign_entity_descriptor(eid, mid, secc)
else: else:
desc = eid desc = eid
valid_instance(desc) valid_instance(desc)
@@ -372,6 +374,21 @@ DEFAULT_BINDING = {
} }
def do_extensions(mname, item):
try:
_mod = __import__("saml2.extension.%s" % mname, globals(), locals(),
mname)
except ImportError:
return None
else:
res = []
for _cname, ava in item.items():
cls = getattr(_mod, _cname)
res.append(rec_factory(cls, **ava))
return res
def _do_nameid_format(cls, conf, typ): def _do_nameid_format(cls, conf, typ):
namef = conf.getattr("name_id_format", typ) namef = conf.getattr("name_id_format", typ)
if namef: if namef:
@@ -421,19 +438,30 @@ def do_spsso_descriptor(conf, cert=None):
spsso = md.SPSSODescriptor() spsso = md.SPSSODescriptor()
spsso.protocol_support_enumeration = samlp.NAMESPACE spsso.protocol_support_enumeration = samlp.NAMESPACE
exts = conf.getattr("extensions", "sp")
if exts:
if spsso.extensions is None:
spsso.extensions = md.Extensions()
for key, val in exts.items():
_ext = do_extensions(key, val)
if _ext:
for _e in _ext:
spsso.extensions.add_extension_element(_e)
endps = conf.getattr("endpoints", "sp") endps = conf.getattr("endpoints", "sp")
if endps: if endps:
for (endpoint, instlist) in do_endpoints(endps, for (endpoint, instlist) in do_endpoints(endps,
ENDPOINTS["sp"]).items(): ENDPOINTS["sp"]).items():
setattr(spsso, endpoint, instlist) setattr(spsso, endpoint, instlist)
ext = do_endpoints(endps, ENDPOINT_EXT["sp"]) # ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
if ext: # if ext:
if spsso.extensions is None: # if spsso.extensions is None:
spsso.extensions = md.Extensions() # spsso.extensions = md.Extensions()
for vals in ext.values(): # for vals in ext.values():
for val in vals: # for val in vals:
spsso.extensions.add_extension_element(val) # spsso.extensions.add_extension_element(val)
if cert: if cert:
encryption_type = conf.encryption_type encryption_type = conf.encryption_type

View File

@@ -413,7 +413,7 @@ def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
"PN": csum.hexdigest(), "PN": csum.hexdigest(),
"AM": ac.AuthnContextClassRef.text "AM": ac.AuthnContextClassRef.text
} }
logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a,v) for a,v in info])) logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a, v) for a, v in info]))
def dynamic_importer(name, class_name=None): def dynamic_importer(name, class_name=None):
@@ -428,14 +428,14 @@ def dynamic_importer(name, class_name=None):
try: try:
package = imp.load_module(name, fp, pathname, description) package = imp.load_module(name, fp, pathname, description)
except Exception, e: except Exception:
raise raise
if class_name: if class_name:
try: try:
_class = imp.load_module("%s.%s" % (name, class_name), fp, _class = imp.load_module("%s.%s" % (name, class_name), fp,
pathname, description) pathname, description)
except Exception, e: except Exception:
raise raise
return package, _class return package, _class
@@ -452,3 +452,34 @@ def exception_trace(exc):
_exc = "Exception: %s" % exc.message.encode("utf-8", "replace") _exc = "Exception: %s" % exc.message.encode("utf-8", "replace")
return {"message": _exc, "content": "".join(message)} return {"message": _exc, "content": "".join(message)}
def rec_factory(cls, **kwargs):
_inst = cls()
for key, val in kwargs.items():
if key in ["text", "lang"]:
setattr(_inst, key, val)
elif key in _inst.c_attributes:
try:
val = str(val)
except Exception:
continue
else:
setattr(_inst, key, val)
elif key in _inst.c_child_order:
for tag, _cls in _inst.c_children.values():
if tag == key:
if isinstance(_cls, list):
_cls = _cls[0]
claim = []
if isinstance(val, list):
for v in val:
claim.append(rec_factory(_cls, **v))
else:
claim.append(rec_factory(_cls, **val))
else:
claim = rec_factory(_cls, **val)
setattr(_inst, key, claim)
break
return _inst

View File

@@ -201,6 +201,18 @@ def valid_unsigned_short(val):
return True return True
def valid_positive_integer(val):
try:
integer = int(val)
except ValueError:
raise NotValid("positive integer")
if integer > 0:
return True
else:
raise NotValid("positive integer")
def valid_non_negative_integer(val): def valid_non_negative_integer(val):
try: try:
integer = int(val) integer = int(val)
@@ -269,6 +281,7 @@ VALIDATOR = {
"dateTime": valid_date_time, "dateTime": valid_date_time,
"anyURI": valid_any_uri, "anyURI": valid_any_uri,
"nonNegativeInteger": valid_non_negative_integer, "nonNegativeInteger": valid_non_negative_integer,
"PositiveInteger": valid_positive_integer,
"boolean": valid_boolean, "boolean": valid_boolean,
"unsignedShort": valid_unsigned_short, "unsignedShort": valid_unsigned_short,
"duration": valid_duration, "duration": valid_duration,