Allow for filtering which entities you want to retain from the metadata.
This commit is contained in:
35
src/saml2/filter.py
Normal file
35
src/saml2/filter.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
__author__ = 'roland'
|
||||||
|
|
||||||
|
class Filter(object):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AllowDescriptor(Filter):
|
||||||
|
def __init__(self, allow):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param allow: List of allowed descriptors
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
super(AllowDescriptor, self).__init__()
|
||||||
|
self.allow = allow
|
||||||
|
|
||||||
|
def __call__(self, entity_descriptor):
|
||||||
|
# get descriptors
|
||||||
|
_all = []
|
||||||
|
for desc in entity_descriptor.keys():
|
||||||
|
if desc.endswith("_descriptor"):
|
||||||
|
typ, _ = desc.rsplit("_", 1)
|
||||||
|
if typ in self.allow:
|
||||||
|
_all.append(typ)
|
||||||
|
else:
|
||||||
|
del entity_descriptor[desc]
|
||||||
|
|
||||||
|
if not _all:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return entity_descriptor
|
66
tests/test_38_metadata_filter.py
Normal file
66
tests/test_38_metadata_filter.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from saml2 import md
|
||||||
|
from saml2 import saml
|
||||||
|
from saml2 import config
|
||||||
|
from saml2 import xmldsig
|
||||||
|
from saml2 import xmlenc
|
||||||
|
|
||||||
|
from saml2.filter import AllowDescriptor
|
||||||
|
from saml2.mdstore import MetadataStore
|
||||||
|
from saml2.attribute_converter import ac_factory
|
||||||
|
from saml2.extension import mdui
|
||||||
|
from saml2.extension import idpdisc
|
||||||
|
from saml2.extension import dri
|
||||||
|
from saml2.extension import mdattr
|
||||||
|
from saml2.extension import ui
|
||||||
|
|
||||||
|
from pathutils import full_path
|
||||||
|
|
||||||
|
__author__ = 'roland'
|
||||||
|
|
||||||
|
sec_config = config.Config()
|
||||||
|
|
||||||
|
ONTS = {
|
||||||
|
saml.NAMESPACE: saml,
|
||||||
|
mdui.NAMESPACE: mdui,
|
||||||
|
mdattr.NAMESPACE: mdattr,
|
||||||
|
dri.NAMESPACE: dri,
|
||||||
|
ui.NAMESPACE: ui,
|
||||||
|
idpdisc.NAMESPACE: idpdisc,
|
||||||
|
md.NAMESPACE: md,
|
||||||
|
xmldsig.NAMESPACE: xmldsig,
|
||||||
|
xmlenc.NAMESPACE: xmlenc
|
||||||
|
}
|
||||||
|
|
||||||
|
ATTRCONV = ac_factory(full_path("attributemaps"))
|
||||||
|
|
||||||
|
METADATACONF = {
|
||||||
|
"1": [{
|
||||||
|
"class": "saml2.mdstore.MetaDataFile",
|
||||||
|
"metadata": [(full_path("swamid-2.0.xml"), )],
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_swamid_sp():
|
||||||
|
mds = MetadataStore(ONTS.values(), ATTRCONV, sec_config,
|
||||||
|
disable_ssl_certificate_validation=True,
|
||||||
|
filter=AllowDescriptor(["spsso"]))
|
||||||
|
|
||||||
|
mds.imp(METADATACONF["1"])
|
||||||
|
sps = mds.with_descriptor("spsso")
|
||||||
|
assert len(sps) == 417
|
||||||
|
idps = mds.with_descriptor("idpsso")
|
||||||
|
assert idps == {}
|
||||||
|
|
||||||
|
def test_swamid_idp():
|
||||||
|
mds = MetadataStore(ONTS.values(), ATTRCONV, sec_config,
|
||||||
|
disable_ssl_certificate_validation=True,
|
||||||
|
filter=AllowDescriptor(["idpsso"]))
|
||||||
|
|
||||||
|
mds.imp(METADATACONF["1"])
|
||||||
|
sps = mds.with_descriptor("spsso")
|
||||||
|
assert len(sps) == 0
|
||||||
|
idps = mds.with_descriptor("idpsso")
|
||||||
|
assert len(idps) == 275
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_swamid_idp()
|
@@ -1,5 +1,5 @@
|
|||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from pymongo.errors import ConnectionFailure
|
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
||||||
from saml2 import BINDING_HTTP_POST
|
from saml2 import BINDING_HTTP_POST
|
||||||
from saml2.authn_context import INTERNETPROTOCOLPASSWORD
|
from saml2.authn_context import INTERNETPROTOCOLPASSWORD
|
||||||
from saml2.client import Saml2Client
|
from saml2.client import Saml2Client
|
||||||
@@ -69,23 +69,27 @@ def test_eptid_mongo_db():
|
|||||||
except ConnectionFailure:
|
except ConnectionFailure:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
e1 = edb.get("idp_entity_id", "sp_entity_id", "user_id",
|
try:
|
||||||
"some other data")
|
e1 = edb.get("idp_entity_id", "sp_entity_id", "user_id",
|
||||||
print(e1)
|
"some other data")
|
||||||
assert e1.startswith("idp_entity_id!sp_entity_id!")
|
except ServerSelectionTimeoutError:
|
||||||
e2 = edb.get("idp_entity_id", "sp_entity_id", "user_id",
|
pass
|
||||||
"some other data")
|
else:
|
||||||
assert e1 == e2
|
print(e1)
|
||||||
|
assert e1.startswith("idp_entity_id!sp_entity_id!")
|
||||||
|
e2 = edb.get("idp_entity_id", "sp_entity_id", "user_id",
|
||||||
|
"some other data")
|
||||||
|
assert e1 == e2
|
||||||
|
|
||||||
e3 = edb.get("idp_entity_id", "sp_entity_id", "user_2",
|
e3 = edb.get("idp_entity_id", "sp_entity_id", "user_2",
|
||||||
"some other data")
|
"some other data")
|
||||||
print(e3)
|
print(e3)
|
||||||
assert e1 != e3
|
assert e1 != e3
|
||||||
|
|
||||||
e4 = edb.get("idp_entity_id", "sp_entity_id2", "user_id",
|
e4 = edb.get("idp_entity_id", "sp_entity_id2", "user_id",
|
||||||
"some other data")
|
"some other data")
|
||||||
assert e4 != e1
|
assert e4 != e1
|
||||||
assert e4 != e3
|
assert e4 != e3
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user