Fixed a problem with filtering assertion by required/optional attributes.

This commit is contained in:
Roland Hedberg
2014-03-27 11:12:41 +01:00
parent e28cf613a2
commit eeb4b5d694
5 changed files with 36 additions and 16 deletions

View File

@@ -24,7 +24,7 @@ import xmlenc
from saml2 import saml from saml2 import saml
from saml2.time_util import instant, in_a_while from saml2.time_util import instant, in_a_while
from saml2.attribute_converter import from_local from saml2.attribute_converter import from_local, get_local_name
from saml2.s_utils import sid, MissingValue from saml2.s_utils import sid, MissingValue
from saml2.s_utils import factory from saml2.s_utils import factory
from saml2.s_utils import assertion_factory from saml2.s_utils import assertion_factory
@@ -78,7 +78,7 @@ def _match(attr, ava):
return None return None
def filter_on_attributes(ava, required=None, optional=None): def filter_on_attributes(ava, required=None, optional=None, acs=None):
""" Filter """ Filter
:param ava: An attribute value assertion as a dictionary :param ava: An attribute value assertion as a dictionary
@@ -98,18 +98,23 @@ def filter_on_attributes(ava, required=None, optional=None):
nform = "" nform = ""
for nform in ["friendly_name", "name"]: for nform in ["friendly_name", "name"]:
try: try:
_fn = _match(attr[nform], ava) _name = attr[nform]
except KeyError: except KeyError:
pass if nform == "friendly_name":
else: _name = get_local_name(acs, attr["name"],
if _fn: attr["name_format"])
try: else:
values = [av["text"] for av in attr["attribute_value"]] continue
except KeyError:
values = [] _fn = _match(_name, ava)
res[_fn] = _filter_values(ava[_fn], values, True) if _fn:
found = True try:
break values = [av["text"] for av in attr["attribute_value"]]
except KeyError:
values = []
res[_fn] = _filter_values(ava[_fn], values, True)
found = True
break
if not found: if not found:
raise MissingValue("Required attribute missing: '%s'" % ( raise MissingValue("Required attribute missing: '%s'" % (
@@ -311,7 +316,8 @@ class Policy(object):
self.compile(restrictions) self.compile(restrictions)
else: else:
self._restrictions = None self._restrictions = None
self.acs = []
def compile(self, restrictions): def compile(self, restrictions):
""" This is only for IdPs or AAs, and it's about limiting what """ This is only for IdPs or AAs, and it's about limiting what
is returned to the SP. is returned to the SP.
@@ -484,7 +490,8 @@ class Policy(object):
ava = filter_attribute_value_assertions(ava, _rest) ava = filter_attribute_value_assertions(ava, _rest)
if required or optional: if required or optional:
ava = filter_on_attributes(ava, required, optional) logger.debug("required: %s, optional: %s" % (required, optional))
ava = filter_on_attributes(ava, required, optional, self.acs)
return ava return ava
@@ -540,7 +547,8 @@ class Assertion(dict):
def __init__(self, dic=None): def __init__(self, dic=None):
dict.__init__(self, dic) dict.__init__(self, dic)
self.acs = []
@staticmethod @staticmethod
def _authn_context_decl(decl, authn_auth=None): def _authn_context_decl(decl, authn_auth=None):
""" """
@@ -727,6 +735,8 @@ class Assertion(dict):
:param metadata: Metadata to use :param metadata: Metadata to use
:return: The resulting AVA after the policy is applied :return: The resulting AVA after the policy is applied
""" """
policy.acs = self.acs
ava = policy.restrict(self, sp_entity_id, metadata) ava = policy.restrict(self, sp_entity_id, metadata)
self.update(ava) self.update(ava)
return ava return ava

View File

@@ -255,6 +255,13 @@ def to_local_name(acs, attr):
return attr.friendly_name return attr.friendly_name
def get_local_name(acs, attr, name_format):
for aconv in acs:
#print ac.format, name_format
if aconv.name_format == name_format:
return aconv._fro[attr]
def d_to_local_name(acs, attr): def d_to_local_name(acs, attr):
""" """
:param acs: List of AttributeConverter instances :param acs: List of AttributeConverter instances

View File

@@ -177,6 +177,7 @@ MAP = {
'edupersonaffiliation': EDUPERSON_OID+'1', 'edupersonaffiliation': EDUPERSON_OID+'1',
'eduPersonPrincipalName': EDUPERSON_OID+'6', 'eduPersonPrincipalName': EDUPERSON_OID+'6',
'edupersonprincipalname': EDUPERSON_OID+'6', 'edupersonprincipalname': EDUPERSON_OID+'6',
'eppn': EDUPERSON_OID+'6',
'localityName': X500ATTR_OID+'7', 'localityName': X500ATTR_OID+'7',
'owner': X500ATTR_OID+'32', 'owner': X500ATTR_OID+'32',
'norEduOrgUnitUniqueNumber': NOREDUPERSON_OID+'2', 'norEduOrgUnitUniqueNumber': NOREDUPERSON_OID+'2',

View File

@@ -308,6 +308,7 @@ class Server(Entity):
#if identity: #if identity:
_issuer = self._issuer(issuer) _issuer = self._issuer(issuer)
ast = Assertion(identity) ast = Assertion(identity)
ast.acs = self.config.getattr("attribute_converters", "idp")
if policy is None: if policy is None:
policy = Policy() policy = Policy()
try: try:

View File

@@ -1,2 +1,3 @@
#!/bin/sh
curl -O -G http://md.swamid.se/md/swamid-2.0.xml curl -O -G http://md.swamid.se/md/swamid-2.0.xml
mdexport.py -t local -o swamid2.md swamid-2.0.xml mdexport.py -t local -o swamid2.md swamid-2.0.xml