From 5cd5ebdc034180ff47767fe9cca4d4eba1bb1f27 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 13 Mar 2014 15:15:20 +0100 Subject: [PATCH] Refactored the support for metadata extension in the config file. --- src/saml2/metadata.py | 50 +++++++++++++++++++++++++++++++++---------- src/saml2/s_utils.py | 39 +++++++++++++++++++++++++++++---- src/saml2/validate.py | 13 +++++++++++ 3 files changed, 87 insertions(+), 15 deletions(-) diff --git a/src/saml2/metadata.py b/src/saml2/metadata.py index acdeac8..7c2efb1 100644 --- a/src/saml2/metadata.py +++ b/src/saml2/metadata.py @@ -23,6 +23,7 @@ import xmldsig as ds from saml2.sigver import pre_signature_part from saml2.s_utils import factory +from saml2.s_utils import rec_factory from saml2.s_utils import sid __author__ = 'rolandh' @@ -51,6 +52,7 @@ ORG_ATTR_TRANSL = { "organization_url": ("url", md.OrganizationURL) } + def metadata_tostring_fix(desc, nspair): MDNS = '"urn:oasis:names:tc:SAML:2.0:metadata"' XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\"" @@ -60,7 +62,7 @@ def metadata_tostring_fix(desc, nspair): return xmlstring -def create_metadata_string(configfile, config, valid, cert, keyfile, id, name, +def create_metadata_string(configfile, config, valid, cert, keyfile, mid, name, sign): valid_for = 0 nspair = {"xs": "http://www.w3.org/2001/XMLSchema"} @@ -85,8 +87,8 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name, conf.xmlsec_binary = config.xmlsec_binary secc = security_context(conf) - if id: - desc = entities_descriptor(eds, valid_for, name, id, + if mid: + desc = entities_descriptor(eds, valid_for, name, mid, sign, secc) valid_instance(desc) @@ -94,7 +96,7 @@ def create_metadata_string(configfile, config, valid, cert, keyfile, id, name, else: for eid in eds: if sign: - desc = sign_entity_descriptor(eid, id, secc) + desc = sign_entity_descriptor(eid, mid, secc) else: desc = eid valid_instance(desc) @@ -372,6 +374,21 @@ DEFAULT_BINDING = { } +def do_extensions(mname, item): + try: + _mod = __import__("saml2.extension.%s" % mname, globals(), locals(), + mname) + except ImportError: + return None + else: + res = [] + + for _cname, ava in item.items(): + cls = getattr(_mod, _cname) + res.append(rec_factory(cls, **ava)) + return res + + def _do_nameid_format(cls, conf, typ): namef = conf.getattr("name_id_format", typ) if namef: @@ -421,19 +438,30 @@ def do_spsso_descriptor(conf, cert=None): spsso = md.SPSSODescriptor() spsso.protocol_support_enumeration = samlp.NAMESPACE + exts = conf.getattr("extensions", "sp") + if exts: + if spsso.extensions is None: + spsso.extensions = md.Extensions() + + for key, val in exts.items(): + _ext = do_extensions(key, val) + if _ext: + for _e in _ext: + spsso.extensions.add_extension_element(_e) + endps = conf.getattr("endpoints", "sp") if endps: for (endpoint, instlist) in do_endpoints(endps, ENDPOINTS["sp"]).items(): setattr(spsso, endpoint, instlist) - ext = do_endpoints(endps, ENDPOINT_EXT["sp"]) - if ext: - if spsso.extensions is None: - spsso.extensions = md.Extensions() - for vals in ext.values(): - for val in vals: - spsso.extensions.add_extension_element(val) + # ext = do_endpoints(endps, ENDPOINT_EXT["sp"]) + # if ext: + # if spsso.extensions is None: + # spsso.extensions = md.Extensions() + # for vals in ext.values(): + # for val in vals: + # spsso.extensions.add_extension_element(val) if cert: encryption_type = conf.encryption_type diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py index 0343467..57fe088 100644 --- a/src/saml2/s_utils.py +++ b/src/saml2/s_utils.py @@ -413,7 +413,7 @@ def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion): "PN": csum.hexdigest(), "AM": ac.AuthnContextClassRef.text } - logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a,v) for a,v in info])) + logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a, v) for a, v in info])) def dynamic_importer(name, class_name=None): @@ -428,14 +428,14 @@ def dynamic_importer(name, class_name=None): try: package = imp.load_module(name, fp, pathname, description) - except Exception, e: + except Exception: raise if class_name: try: _class = imp.load_module("%s.%s" % (name, class_name), fp, - pathname, description) - except Exception, e: + pathname, description) + except Exception: raise return package, _class @@ -452,3 +452,34 @@ def exception_trace(exc): _exc = "Exception: %s" % exc.message.encode("utf-8", "replace") return {"message": _exc, "content": "".join(message)} + + +def rec_factory(cls, **kwargs): + _inst = cls() + for key, val in kwargs.items(): + if key in ["text", "lang"]: + setattr(_inst, key, val) + elif key in _inst.c_attributes: + try: + val = str(val) + except Exception: + continue + else: + setattr(_inst, key, val) + elif key in _inst.c_child_order: + for tag, _cls in _inst.c_children.values(): + if tag == key: + if isinstance(_cls, list): + _cls = _cls[0] + claim = [] + if isinstance(val, list): + for v in val: + claim.append(rec_factory(_cls, **v)) + else: + claim.append(rec_factory(_cls, **val)) + else: + claim = rec_factory(_cls, **val) + setattr(_inst, key, claim) + break + + return _inst diff --git a/src/saml2/validate.py b/src/saml2/validate.py index 3ba2dff..376df9e 100644 --- a/src/saml2/validate.py +++ b/src/saml2/validate.py @@ -201,6 +201,18 @@ def valid_unsigned_short(val): return True +def valid_positive_integer(val): + try: + integer = int(val) + except ValueError: + raise NotValid("positive integer") + + if integer > 0: + return True + else: + raise NotValid("positive integer") + + def valid_non_negative_integer(val): try: integer = int(val) @@ -269,6 +281,7 @@ VALIDATOR = { "dateTime": valid_date_time, "anyURI": valid_any_uri, "nonNegativeInteger": valid_non_negative_integer, + "PositiveInteger": valid_positive_integer, "boolean": valid_boolean, "unsignedShort": valid_unsigned_short, "duration": valid_duration,