Refactored the support for metadata extension in the config file.
This commit is contained in:
		@@ -23,6 +23,7 @@ import xmldsig as ds
 | 
			
		||||
from saml2.sigver import pre_signature_part
 | 
			
		||||
 | 
			
		||||
from saml2.s_utils import factory
 | 
			
		||||
from saml2.s_utils import rec_factory
 | 
			
		||||
from saml2.s_utils import sid
 | 
			
		||||
 | 
			
		||||
__author__ = 'rolandh'
 | 
			
		||||
@@ -51,6 +52,7 @@ ORG_ATTR_TRANSL = {
 | 
			
		||||
    "organization_url": ("url", md.OrganizationURL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def metadata_tostring_fix(desc, nspair):
 | 
			
		||||
    MDNS = '"urn:oasis:names:tc:SAML:2.0:metadata"'
 | 
			
		||||
    XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
 | 
			
		||||
@@ -60,7 +62,7 @@ def metadata_tostring_fix(desc, nspair):
 | 
			
		||||
    return xmlstring
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
 | 
			
		||||
def create_metadata_string(configfile, config, valid, cert, keyfile, mid, name,
 | 
			
		||||
                           sign):
 | 
			
		||||
    valid_for = 0
 | 
			
		||||
    nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
 | 
			
		||||
@@ -85,8 +87,8 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
 | 
			
		||||
    conf.xmlsec_binary = config.xmlsec_binary
 | 
			
		||||
    secc = security_context(conf)
 | 
			
		||||
 | 
			
		||||
    if id:
 | 
			
		||||
        desc = entities_descriptor(eds, valid_for, name, id,
 | 
			
		||||
    if mid:
 | 
			
		||||
        desc = entities_descriptor(eds, valid_for, name, mid,
 | 
			
		||||
                                   sign, secc)
 | 
			
		||||
        valid_instance(desc)
 | 
			
		||||
 | 
			
		||||
@@ -94,7 +96,7 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name,
 | 
			
		||||
    else:
 | 
			
		||||
        for eid in eds:
 | 
			
		||||
            if sign:
 | 
			
		||||
                desc = sign_entity_descriptor(eid, id, secc)
 | 
			
		||||
                desc = sign_entity_descriptor(eid, mid, secc)
 | 
			
		||||
            else:
 | 
			
		||||
                desc = eid
 | 
			
		||||
            valid_instance(desc)
 | 
			
		||||
@@ -372,6 +374,21 @@ DEFAULT_BINDING = {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def do_extensions(mname, item):
 | 
			
		||||
    try:
 | 
			
		||||
        _mod = __import__("saml2.extension.%s" % mname, globals(), locals(),
 | 
			
		||||
                          mname)
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        return None
 | 
			
		||||
    else:
 | 
			
		||||
        res = []
 | 
			
		||||
 | 
			
		||||
        for _cname, ava in item.items():
 | 
			
		||||
            cls = getattr(_mod, _cname)
 | 
			
		||||
            res.append(rec_factory(cls, **ava))
 | 
			
		||||
    return res
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _do_nameid_format(cls, conf, typ):
 | 
			
		||||
    namef = conf.getattr("name_id_format", typ)
 | 
			
		||||
    if namef:
 | 
			
		||||
@@ -421,19 +438,30 @@ def do_spsso_descriptor(conf, cert=None):
 | 
			
		||||
    spsso = md.SPSSODescriptor()
 | 
			
		||||
    spsso.protocol_support_enumeration = samlp.NAMESPACE
 | 
			
		||||
 | 
			
		||||
    exts = conf.getattr("extensions", "sp")
 | 
			
		||||
    if exts:
 | 
			
		||||
        if spsso.extensions is None:
 | 
			
		||||
            spsso.extensions = md.Extensions()
 | 
			
		||||
 | 
			
		||||
        for key, val in exts.items():
 | 
			
		||||
            _ext = do_extensions(key, val)
 | 
			
		||||
            if _ext:
 | 
			
		||||
                for _e in _ext:
 | 
			
		||||
                    spsso.extensions.add_extension_element(_e)
 | 
			
		||||
 | 
			
		||||
    endps = conf.getattr("endpoints", "sp")
 | 
			
		||||
    if endps:
 | 
			
		||||
        for (endpoint, instlist) in do_endpoints(endps,
 | 
			
		||||
                                                 ENDPOINTS["sp"]).items():
 | 
			
		||||
            setattr(spsso, endpoint, instlist)
 | 
			
		||||
 | 
			
		||||
    ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
 | 
			
		||||
    if ext:
 | 
			
		||||
        if spsso.extensions is None:
 | 
			
		||||
            spsso.extensions = md.Extensions()
 | 
			
		||||
        for vals in ext.values():
 | 
			
		||||
            for val in vals:
 | 
			
		||||
                spsso.extensions.add_extension_element(val)
 | 
			
		||||
    # ext = do_endpoints(endps, ENDPOINT_EXT["sp"])
 | 
			
		||||
    # if ext:
 | 
			
		||||
    #     if spsso.extensions is None:
 | 
			
		||||
    #         spsso.extensions = md.Extensions()
 | 
			
		||||
    #     for vals in ext.values():
 | 
			
		||||
    #         for val in vals:
 | 
			
		||||
    #             spsso.extensions.add_extension_element(val)
 | 
			
		||||
 | 
			
		||||
    if cert:
 | 
			
		||||
        encryption_type = conf.encryption_type
 | 
			
		||||
 
 | 
			
		||||
@@ -413,7 +413,7 @@ def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
 | 
			
		||||
        "PN": csum.hexdigest(),
 | 
			
		||||
        "AM": ac.AuthnContextClassRef.text
 | 
			
		||||
    }
 | 
			
		||||
    logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a,v) for a,v in info]))
 | 
			
		||||
    logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a, v) for a, v in info]))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dynamic_importer(name, class_name=None):
 | 
			
		||||
@@ -428,14 +428,14 @@ def dynamic_importer(name, class_name=None):
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        package = imp.load_module(name, fp, pathname, description)
 | 
			
		||||
    except Exception, e:
 | 
			
		||||
    except Exception:
 | 
			
		||||
        raise
 | 
			
		||||
 | 
			
		||||
    if class_name:
 | 
			
		||||
        try:
 | 
			
		||||
            _class = imp.load_module("%s.%s" % (name, class_name), fp,
 | 
			
		||||
                                      pathname, description)
 | 
			
		||||
        except Exception, e:
 | 
			
		||||
                                     pathname, description)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
        return package, _class
 | 
			
		||||
@@ -452,3 +452,34 @@ def exception_trace(exc):
 | 
			
		||||
        _exc = "Exception: %s" % exc.message.encode("utf-8", "replace")
 | 
			
		||||
 | 
			
		||||
    return {"message": _exc, "content": "".join(message)}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rec_factory(cls, **kwargs):
 | 
			
		||||
    _inst = cls()
 | 
			
		||||
    for key, val in kwargs.items():
 | 
			
		||||
        if key in ["text", "lang"]:
 | 
			
		||||
            setattr(_inst, key, val)
 | 
			
		||||
        elif key in _inst.c_attributes:
 | 
			
		||||
            try:
 | 
			
		||||
                val = str(val)
 | 
			
		||||
            except Exception:
 | 
			
		||||
                continue
 | 
			
		||||
            else:
 | 
			
		||||
                setattr(_inst, key, val)
 | 
			
		||||
        elif key in _inst.c_child_order:
 | 
			
		||||
            for tag, _cls in _inst.c_children.values():
 | 
			
		||||
                if tag == key:
 | 
			
		||||
                    if isinstance(_cls, list):
 | 
			
		||||
                        _cls = _cls[0]
 | 
			
		||||
                        claim = []
 | 
			
		||||
                        if isinstance(val, list):
 | 
			
		||||
                            for v in val:
 | 
			
		||||
                                claim.append(rec_factory(_cls, **v))
 | 
			
		||||
                        else:
 | 
			
		||||
                            claim.append(rec_factory(_cls, **val))
 | 
			
		||||
                    else:
 | 
			
		||||
                        claim = rec_factory(_cls, **val)
 | 
			
		||||
                    setattr(_inst, key, claim)
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
    return _inst
 | 
			
		||||
 
 | 
			
		||||
@@ -201,6 +201,18 @@ def valid_unsigned_short(val):
 | 
			
		||||
    return True
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
def valid_positive_integer(val):
 | 
			
		||||
    try:
 | 
			
		||||
        integer = int(val)
 | 
			
		||||
    except ValueError:
 | 
			
		||||
        raise NotValid("positive integer")
 | 
			
		||||
 | 
			
		||||
    if integer > 0:
 | 
			
		||||
        return True
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotValid("positive integer")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def valid_non_negative_integer(val):
 | 
			
		||||
    try:
 | 
			
		||||
        integer = int(val)
 | 
			
		||||
@@ -269,6 +281,7 @@ VALIDATOR = {
 | 
			
		||||
    "dateTime": valid_date_time,
 | 
			
		||||
    "anyURI": valid_any_uri,
 | 
			
		||||
    "nonNegativeInteger": valid_non_negative_integer,
 | 
			
		||||
    "PositiveInteger": valid_positive_integer,
 | 
			
		||||
    "boolean": valid_boolean,
 | 
			
		||||
    "unsignedShort": valid_unsigned_short,
 | 
			
		||||
    "duration": valid_duration,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user