Refactored class methods

This commit is contained in:
Roland Hedberg
2014-03-08 19:40:54 +01:00
parent 6b1bc50961
commit 7b41b464c5
3 changed files with 97 additions and 109 deletions

View File

@@ -362,7 +362,7 @@ class ACS(Service):
def verify_attributes(self, ava):
logger.info("SP: %s" % self.sp.config.entityid)
rest = POLICY.get_entity_categories_restriction(
rest = POLICY.get_entity_categories(
self.sp.config.entityid, self.sp.metadata)
akeys = [k.lower() for k in ava.keys()]

View File

@@ -270,6 +270,39 @@ def restriction_from_attribute_spec(attributes):
return restr
def post_entity_categories(maps, **kwargs):
restrictions = {}
if kwargs["mds"]:
try:
ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"])
except KeyError:
for ec_map in maps:
for attr in ec_map[""]:
restrictions[attr] = None
else:
for ec_map in maps:
for key, val in ec_map.items():
if key == "": # always released
attrs = val
elif isinstance(key, tuple):
attrs = val
for _key in key:
try:
assert _key in ecs
except AssertionError:
attrs = []
break
elif key in ecs:
attrs = val
else:
attrs = []
for attr in attrs:
restrictions[attr] = None
return restrictions
class Policy(object):
""" handles restrictions on assertions """
@@ -330,39 +363,52 @@ class Policy(object):
return self._restrictions
def get(self, attribute, sp_entity_id, default=None, post_func=None,
**kwargs):
"""
:param attribute:
:param sp_entity_id:
:param default:
:param post_func:
:return:
"""
if not self._restrictions:
return default
try:
try:
val = self._restrictions[sp_entity_id][attribute]
except KeyError:
try:
val = self._restrictions["default"][attribute]
except KeyError:
val = None
except KeyError:
val = None
if val is None:
return default
elif post_func:
return post_func(val, sp_entity_id=sp_entity_id, **kwargs)
else:
return val
def get_nameid_format(self, sp_entity_id):
""" Get the NameIDFormat to used for the entity id
:param: The SP entity ID
:retur: The format
"""
try:
form = self._restrictions[sp_entity_id]["nameid_format"]
except KeyError:
try:
form = self._restrictions["default"]["nameid_format"]
except KeyError:
form = saml.NAMEID_FORMAT_TRANSIENT
return form
return self.get("nameid_format", sp_entity_id,
saml.NAMEID_FORMAT_TRANSIENT)
def get_name_form(self, sp_entity_id):
""" Get the NameFormat to used for the entity id
:param: The SP entity ID
:retur: The format
"""
form = NAME_FORMAT_URI
try:
form = self._restrictions[sp_entity_id]["name_form"]
except TypeError:
pass
except KeyError:
try:
form = self._restrictions["default"]["name_form"]
except KeyError:
pass
return form
return self.get("name_format", sp_entity_id, NAME_FORMAT_URI)
def get_lifetime(self, sp_entity_id):
""" The lifetime of the assertion
@@ -370,44 +416,16 @@ class Policy(object):
:param: lifetime as a dictionary
"""
# default is a hour
spec = {"hours": 1}
if not self._restrictions:
return spec
return self.get("lifetime", sp_entity_id, {"hours": 1})
try:
spec = self._restrictions[sp_entity_id]["lifetime"]
except KeyError:
try:
spec = self._restrictions["default"]["lifetime"]
except KeyError:
pass
return spec
def get_attribute_restriction(self, sp_entity_id):
def get_attribute_restrictions(self, sp_entity_id):
""" Return the attribute restriction for SP that want the information
:param sp_entity_id: The SP entity ID
:return: The restrictions
"""
if not self._restrictions:
return None
try:
try:
restrictions = self._restrictions[sp_entity_id][
"attribute_restrictions"]
except KeyError:
try:
restrictions = self._restrictions["default"][
"attribute_restrictions"]
except KeyError:
restrictions = None
except KeyError:
restrictions = None
return restrictions
return self.get("attribute_restrictions", sp_entity_id)
def entity_category_attributes(self, ec):
if not self._restrictions:
@@ -421,59 +439,18 @@ class Policy(object):
pass
return []
def get_entity_categories_restriction(self, sp_entity_id, mds):
def get_entity_categories(self, sp_entity_id, mds):
"""
:param sp_entity_id:
:param mds: MetadataStore instance
:return: A dictionary with restrictionsmetat
:return: A dictionary with restrictions
"""
if not self._restrictions:
return None
restrictions = {}
ec_maps = []
try:
try:
ec_maps = self._restrictions[sp_entity_id]["entity_categories"]
except KeyError:
try:
ec_maps = self._restrictions["default"]["entity_categories"]
except KeyError:
pass
except KeyError:
pass
kwargs = {"mds": mds}
if ec_maps:
if mds:
try:
ecs = mds.entity_categories(sp_entity_id)
except KeyError:
for ec_map in ec_maps:
for attr in ec_map[""]:
restrictions[attr] = None
else:
for ec_map in ec_maps:
for key, val in ec_map.items():
if key == "": # always released
attrs = val
elif isinstance(key, tuple):
attrs = val
for _key in key:
try:
assert _key in ecs
except AssertionError:
attrs = []
break
elif key in ecs:
attrs = val
else:
attrs = []
for attr in attrs:
restrictions[attr] = None
return restrictions
return self.get("entity_categories", sp_entity_id, default={},
post_func=post_entity_categories, **kwargs)
def not_on_or_after(self, sp_entity_id):
""" When the assertion stops being valid, should not be
@@ -500,10 +477,9 @@ class Policy(object):
:return: A possibly modified AVA
"""
_rest = self.get_attribute_restriction(sp_entity_id)
_rest = self.get_attribute_restrictions(sp_entity_id)
if _rest is None:
_rest = self.get_entity_categories_restriction(sp_entity_id,
mdstore)
_rest = self.get_entity_categories(sp_entity_id, mdstore)
logger.debug("filter based on: %s" % _rest)
ava = filter_attribute_value_assertions(ava, _rest)
@@ -543,6 +519,17 @@ class Policy(object):
audience=[factory(saml.Audience,
text=sp_entity_id)])])
def get_sign(self, sp_entity_id):
"""
Possible choices
"sign": ["response", "assertion", "on_demand"]
:param sp_entity_id:
:return:
"""
return self.get("sign", sp_entity_id, [])
class EntityCategories(object):
pass

View File

@@ -1,3 +1,4 @@
# metadata extensions mainly
__author__ = 'rolandh'
__all__ = ["dri", "mdrpi", "mdui", "shibmd", "idpdisc"]
__all__ = ["dri", "mdrpi", "mdui", "shibmd", "idpdisc", 'algsupport',
'mdattr', 'ui']