From 9994d2646481d63b3c207078b2bc034348b93357 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Wed, 27 May 2015 11:49:32 +0200 Subject: [PATCH] Allow for filtering which entities you want to retain from the metadata. --- src/saml2/filter.py | 35 +++++++++++++++++ tests/test_38_metadata_filter.py | 66 ++++++++++++++++++++++++++++++++ tests/test_75_mongodb.py | 36 +++++++++-------- 3 files changed, 121 insertions(+), 16 deletions(-) create mode 100644 src/saml2/filter.py create mode 100644 tests/test_38_metadata_filter.py diff --git a/src/saml2/filter.py b/src/saml2/filter.py new file mode 100644 index 0000000..0be7a32 --- /dev/null +++ b/src/saml2/filter.py @@ -0,0 +1,35 @@ +__author__ = 'roland' + +class Filter(object): + def __init__(self): + pass + + def __call__(self, *args, **kwargs): + pass + + +class AllowDescriptor(Filter): + def __init__(self, allow): + """ + + :param allow: List of allowed descriptors + :return: + """ + super(AllowDescriptor, self).__init__() + self.allow = allow + + def __call__(self, entity_descriptor): + # get descriptors + _all = [] + for desc in entity_descriptor.keys(): + if desc.endswith("_descriptor"): + typ, _ = desc.rsplit("_", 1) + if typ in self.allow: + _all.append(typ) + else: + del entity_descriptor[desc] + + if not _all: + return None + else: + return entity_descriptor diff --git a/tests/test_38_metadata_filter.py b/tests/test_38_metadata_filter.py new file mode 100644 index 0000000..f23d49c --- /dev/null +++ b/tests/test_38_metadata_filter.py @@ -0,0 +1,66 @@ +from saml2 import md +from saml2 import saml +from saml2 import config +from saml2 import xmldsig +from saml2 import xmlenc + +from saml2.filter import AllowDescriptor +from saml2.mdstore import MetadataStore +from saml2.attribute_converter import ac_factory +from saml2.extension import mdui +from saml2.extension import idpdisc +from saml2.extension import dri +from saml2.extension import mdattr +from saml2.extension import ui + +from pathutils import full_path + +__author__ = 'roland' + +sec_config = config.Config() + +ONTS = { + saml.NAMESPACE: saml, + mdui.NAMESPACE: mdui, + mdattr.NAMESPACE: mdattr, + dri.NAMESPACE: dri, + ui.NAMESPACE: ui, + idpdisc.NAMESPACE: idpdisc, + md.NAMESPACE: md, + xmldsig.NAMESPACE: xmldsig, + xmlenc.NAMESPACE: xmlenc +} + +ATTRCONV = ac_factory(full_path("attributemaps")) + +METADATACONF = { + "1": [{ + "class": "saml2.mdstore.MetaDataFile", + "metadata": [(full_path("swamid-2.0.xml"), )], + }], +} + +def test_swamid_sp(): + mds = MetadataStore(ONTS.values(), ATTRCONV, sec_config, + disable_ssl_certificate_validation=True, + filter=AllowDescriptor(["spsso"])) + + mds.imp(METADATACONF["1"]) + sps = mds.with_descriptor("spsso") + assert len(sps) == 417 + idps = mds.with_descriptor("idpsso") + assert idps == {} + +def test_swamid_idp(): + mds = MetadataStore(ONTS.values(), ATTRCONV, sec_config, + disable_ssl_certificate_validation=True, + filter=AllowDescriptor(["idpsso"])) + + mds.imp(METADATACONF["1"]) + sps = mds.with_descriptor("spsso") + assert len(sps) == 0 + idps = mds.with_descriptor("idpsso") + assert len(idps) == 275 + +if __name__ == "__main__": + test_swamid_idp() diff --git a/tests/test_75_mongodb.py b/tests/test_75_mongodb.py index f79baff..9d9893c 100644 --- a/tests/test_75_mongodb.py +++ b/tests/test_75_mongodb.py @@ -1,5 +1,5 @@ from contextlib import closing -from pymongo.errors import ConnectionFailure +from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError from saml2 import BINDING_HTTP_POST from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.client import Saml2Client @@ -69,23 +69,27 @@ def test_eptid_mongo_db(): except ConnectionFailure: pass else: - e1 = edb.get("idp_entity_id", "sp_entity_id", "user_id", - "some other data") - print(e1) - assert e1.startswith("idp_entity_id!sp_entity_id!") - e2 = edb.get("idp_entity_id", "sp_entity_id", "user_id", - "some other data") - assert e1 == e2 + try: + e1 = edb.get("idp_entity_id", "sp_entity_id", "user_id", + "some other data") + except ServerSelectionTimeoutError: + pass + else: + print(e1) + assert e1.startswith("idp_entity_id!sp_entity_id!") + e2 = edb.get("idp_entity_id", "sp_entity_id", "user_id", + "some other data") + assert e1 == e2 - e3 = edb.get("idp_entity_id", "sp_entity_id", "user_2", - "some other data") - print(e3) - assert e1 != e3 + e3 = edb.get("idp_entity_id", "sp_entity_id", "user_2", + "some other data") + print(e3) + assert e1 != e3 - e4 = edb.get("idp_entity_id", "sp_entity_id2", "user_id", - "some other data") - assert e4 != e1 - assert e4 != e3 + e4 = edb.get("idp_entity_id", "sp_entity_id2", "user_id", + "some other data") + assert e4 != e1 + assert e4 != e3