diff --git a/tools/make_metadata.py b/tools/make_metadata.py index e889d70..3737941 100755 --- a/tools/make_metadata.py +++ b/tools/make_metadata.py @@ -7,6 +7,7 @@ from saml2.time_util import in_a_while from saml2.utils import parse_attribute_map, args2dict from saml2.saml import NAME_FORMAT_URI from saml2.sigver import pre_signature_part, SecurityContext +from saml2.attribute_converter import from_local_name, ac_factory HELP_MESSAGE = """ Usage: make_metadata [options] 1*configurationfile @@ -87,7 +88,7 @@ def do_contact_person_info(conf, desc): pass desc["contact_person"].append(dorg) -def do_sp_sso_descriptor(sp, cert, backward_map): +def do_sp_sso_descriptor(sp, cert, acs): desc = { "protocol_support_enumeration": samlp.NAMESPACE, "assertion_consumer_service": { @@ -113,35 +114,14 @@ def do_sp_sso_descriptor(sp, cert, backward_map): requested_attribute = [] if "required_attributes" in sp: for attr in sp["required_attributes"]: - try: - requested_attribute.append({ - "is_required": "true", - "friendly_name": attr, - "name_format": NAME_FORMAT_URI, - "name": backward_map[attr][0] - }) - except KeyError: - requested_attribute.append({ - "is_required": "true", - "friendly_name": attr, - "name_format": NAME_FORMAT_URI, - "name": attr - }) + reqa = from_local_name(acs, attr, NAME_FORMAT_URI) + reqa["is_required"] = "true" + requested_attribute.append(reqa) if "optional_attributes" in sp: for attr in sp["optional_attributes"]: - try: - requested_attribute.append({ - "friendly_name": attr, - "name_format": NAME_FORMAT_URI, - "name": backward_map[attr][0] - }) - except KeyError: - requested_attribute.append({ - "friendly_name": attr, - "name_format": NAME_FORMAT_URI, - "name": attr - }) + reqa = from_local_name(acs, attr, NAME_FORMAT_URI) + requested_attribute.append(reqa) if requested_attribute: desc["attribute_consuming_service"] = { @@ -203,10 +183,15 @@ def do_aa_descriptor(aa, cert): def entity_descriptor(confd, valid_for): mycert = "".join(open(confd["cert_file"]).readlines()[1:-1]) - if "attribute_maps" in confd: - (forward,backward) = parse_attribute_map(confd["attribute_maps"]) + if "attribute_map_dir" in confd: + attrconverters = ac_factory(confd["attribute_map_dir"]) else: - backward = {} + attrconverters = [AttributeConverter()] + + #if "attribute_maps" in confd: + # (forward,backward) = parse_attribute_map(confd["attribute_maps"]) + #else: + # backward = {} ed = { "entity_id": confd["entityid"], @@ -220,7 +205,7 @@ def entity_descriptor(confd, valid_for): if "sp" in confd["service"]: # The SP ed["sp_sso_descriptor"] = do_sp_sso_descriptor(confd["service"]["sp"], - mycert, backward) + mycert, attrconverters) if "idp" in confd["service"]: ed["idp_sso_descriptor"] = do_idp_sso_descriptor( confd["service"]["idp"], mycert)