From 7b41b464c5bf36206353bb6c8665473bbeadc2e5 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sat, 8 Mar 2014 19:40:54 +0100 Subject: [PATCH] Refactored class methods --- example/sp-wsgi/sp.py | 2 +- src/saml2/assertion.py | 201 +++++++++++++++----------------- src/saml2/extension/__init__.py | 3 +- 3 files changed, 97 insertions(+), 109 deletions(-) diff --git a/example/sp-wsgi/sp.py b/example/sp-wsgi/sp.py index 53137e7..c6056fe 100755 --- a/example/sp-wsgi/sp.py +++ b/example/sp-wsgi/sp.py @@ -362,7 +362,7 @@ class ACS(Service): def verify_attributes(self, ava): logger.info("SP: %s" % self.sp.config.entityid) - rest = POLICY.get_entity_categories_restriction( + rest = POLICY.get_entity_categories( self.sp.config.entityid, self.sp.metadata) akeys = [k.lower() for k in ava.keys()] diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index 4073b6b..a0bccd6 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -270,6 +270,39 @@ def restriction_from_attribute_spec(attributes): return restr +def post_entity_categories(maps, **kwargs): + restrictions = {} + if kwargs["mds"]: + try: + ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"]) + except KeyError: + for ec_map in maps: + for attr in ec_map[""]: + restrictions[attr] = None + else: + for ec_map in maps: + for key, val in ec_map.items(): + if key == "": # always released + attrs = val + elif isinstance(key, tuple): + attrs = val + for _key in key: + try: + assert _key in ecs + except AssertionError: + attrs = [] + break + elif key in ecs: + attrs = val + else: + attrs = [] + + for attr in attrs: + restrictions[attr] = None + + return restrictions + + class Policy(object): """ handles restrictions on assertions """ @@ -329,85 +362,70 @@ class Policy(object): logger.debug("policy restrictions: %s" % self._restrictions) return self._restrictions - + + def get(self, attribute, sp_entity_id, default=None, post_func=None, + **kwargs): + """ + + :param attribute: + :param sp_entity_id: + :param default: + :param post_func: + :return: + """ + if not self._restrictions: + return default + + try: + try: + val = self._restrictions[sp_entity_id][attribute] + except KeyError: + try: + val = self._restrictions["default"][attribute] + except KeyError: + val = None + except KeyError: + val = None + + if val is None: + return default + elif post_func: + return post_func(val, sp_entity_id=sp_entity_id, **kwargs) + else: + return val + def get_nameid_format(self, sp_entity_id): """ Get the NameIDFormat to used for the entity id :param: The SP entity ID :retur: The format """ - try: - form = self._restrictions[sp_entity_id]["nameid_format"] - except KeyError: - try: - form = self._restrictions["default"]["nameid_format"] - except KeyError: - form = saml.NAMEID_FORMAT_TRANSIENT - - return form - + return self.get("nameid_format", sp_entity_id, + saml.NAMEID_FORMAT_TRANSIENT) + def get_name_form(self, sp_entity_id): """ Get the NameFormat to used for the entity id :param: The SP entity ID :retur: The format """ - form = NAME_FORMAT_URI - - try: - form = self._restrictions[sp_entity_id]["name_form"] - except TypeError: - pass - except KeyError: - try: - form = self._restrictions["default"]["name_form"] - except KeyError: - pass - - return form - + + return self.get("name_format", sp_entity_id, NAME_FORMAT_URI) + def get_lifetime(self, sp_entity_id): """ The lifetime of the assertion :param sp_entity_id: The SP entity ID :param: lifetime as a dictionary """ # default is a hour - spec = {"hours": 1} - if not self._restrictions: - return spec - - try: - spec = self._restrictions[sp_entity_id]["lifetime"] - except KeyError: - try: - spec = self._restrictions["default"]["lifetime"] - except KeyError: - pass - - return spec - - def get_attribute_restriction(self, sp_entity_id): + return self.get("lifetime", sp_entity_id, {"hours": 1}) + + def get_attribute_restrictions(self, sp_entity_id): """ Return the attribute restriction for SP that want the information :param sp_entity_id: The SP entity ID :return: The restrictions """ - - if not self._restrictions: - return None - - try: - try: - restrictions = self._restrictions[sp_entity_id][ - "attribute_restrictions"] - except KeyError: - try: - restrictions = self._restrictions["default"][ - "attribute_restrictions"] - except KeyError: - restrictions = None - except KeyError: - restrictions = None - - return restrictions + + return self.get("attribute_restrictions", sp_entity_id) def entity_category_attributes(self, ec): if not self._restrictions: @@ -421,59 +439,18 @@ class Policy(object): pass return [] - def get_entity_categories_restriction(self, sp_entity_id, mds): + def get_entity_categories(self, sp_entity_id, mds): """ :param sp_entity_id: :param mds: MetadataStore instance - :return: A dictionary with restrictionsmetat + :return: A dictionary with restrictions """ - if not self._restrictions: - return None - restrictions = {} - ec_maps = [] - try: - try: - ec_maps = self._restrictions[sp_entity_id]["entity_categories"] - except KeyError: - try: - ec_maps = self._restrictions["default"]["entity_categories"] - except KeyError: - pass - except KeyError: - pass + kwargs = {"mds": mds} - if ec_maps: - if mds: - try: - ecs = mds.entity_categories(sp_entity_id) - except KeyError: - for ec_map in ec_maps: - for attr in ec_map[""]: - restrictions[attr] = None - else: - for ec_map in ec_maps: - for key, val in ec_map.items(): - if key == "": # always released - attrs = val - elif isinstance(key, tuple): - attrs = val - for _key in key: - try: - assert _key in ecs - except AssertionError: - attrs = [] - break - elif key in ecs: - attrs = val - else: - attrs = [] - - for attr in attrs: - restrictions[attr] = None - - return restrictions + return self.get("entity_categories", sp_entity_id, default={}, + post_func=post_entity_categories, **kwargs) def not_on_or_after(self, sp_entity_id): """ When the assertion stops being valid, should not be @@ -500,10 +477,9 @@ class Policy(object): :return: A possibly modified AVA """ - _rest = self.get_attribute_restriction(sp_entity_id) + _rest = self.get_attribute_restrictions(sp_entity_id) if _rest is None: - _rest = self.get_entity_categories_restriction(sp_entity_id, - mdstore) + _rest = self.get_entity_categories(sp_entity_id, mdstore) logger.debug("filter based on: %s" % _rest) ava = filter_attribute_value_assertions(ava, _rest) @@ -543,6 +519,17 @@ class Policy(object): audience=[factory(saml.Audience, text=sp_entity_id)])]) + def get_sign(self, sp_entity_id): + """ + Possible choices + "sign": ["response", "assertion", "on_demand"] + + :param sp_entity_id: + :return: + """ + + return self.get("sign", sp_entity_id, []) + class EntityCategories(object): pass diff --git a/src/saml2/extension/__init__.py b/src/saml2/extension/__init__.py index 836db2e..34752f1 100644 --- a/src/saml2/extension/__init__.py +++ b/src/saml2/extension/__init__.py @@ -1,3 +1,4 @@ # metadata extensions mainly __author__ = 'rolandh' -__all__ = ["dri", "mdrpi", "mdui", "shibmd", "idpdisc"] \ No newline at end of file +__all__ = ["dri", "mdrpi", "mdui", "shibmd", "idpdisc", 'algsupport', + 'mdattr', 'ui'] \ No newline at end of file