Fixed a problem with filtering assertion by required/optional attributes.

This commit is contained in:
Roland Hedberg
2014-03-27 11:12:41 +01:00
parent e28cf613a2
commit eeb4b5d694
5 changed files with 36 additions and 16 deletions

View File

@@ -24,7 +24,7 @@ import xmlenc
from saml2 import saml from saml2 import saml
from saml2.time_util import instant, in_a_while from saml2.time_util import instant, in_a_while
from saml2.attribute_converter import from_local from saml2.attribute_converter import from_local, get_local_name
from saml2.s_utils import sid, MissingValue from saml2.s_utils import sid, MissingValue
from saml2.s_utils import factory from saml2.s_utils import factory
from saml2.s_utils import assertion_factory from saml2.s_utils import assertion_factory
@@ -78,7 +78,7 @@ def _match(attr, ava):
return None return None
def filter_on_attributes(ava, required=None, optional=None): def filter_on_attributes(ava, required=None, optional=None, acs=None):
""" Filter """ Filter
:param ava: An attribute value assertion as a dictionary :param ava: An attribute value assertion as a dictionary
@@ -98,10 +98,15 @@ def filter_on_attributes(ava, required=None, optional=None):
nform = "" nform = ""
for nform in ["friendly_name", "name"]: for nform in ["friendly_name", "name"]:
try: try:
_fn = _match(attr[nform], ava) _name = attr[nform]
except KeyError: except KeyError:
pass if nform == "friendly_name":
_name = get_local_name(acs, attr["name"],
attr["name_format"])
else: else:
continue
_fn = _match(_name, ava)
if _fn: if _fn:
try: try:
values = [av["text"] for av in attr["attribute_value"]] values = [av["text"] for av in attr["attribute_value"]]
@@ -311,6 +316,7 @@ class Policy(object):
self.compile(restrictions) self.compile(restrictions)
else: else:
self._restrictions = None self._restrictions = None
self.acs = []
def compile(self, restrictions): def compile(self, restrictions):
""" This is only for IdPs or AAs, and it's about limiting what """ This is only for IdPs or AAs, and it's about limiting what
@@ -484,7 +490,8 @@ class Policy(object):
ava = filter_attribute_value_assertions(ava, _rest) ava = filter_attribute_value_assertions(ava, _rest)
if required or optional: if required or optional:
ava = filter_on_attributes(ava, required, optional) logger.debug("required: %s, optional: %s" % (required, optional))
ava = filter_on_attributes(ava, required, optional, self.acs)
return ava return ava
@@ -540,6 +547,7 @@ class Assertion(dict):
def __init__(self, dic=None): def __init__(self, dic=None):
dict.__init__(self, dic) dict.__init__(self, dic)
self.acs = []
@staticmethod @staticmethod
def _authn_context_decl(decl, authn_auth=None): def _authn_context_decl(decl, authn_auth=None):
@@ -727,6 +735,8 @@ class Assertion(dict):
:param metadata: Metadata to use :param metadata: Metadata to use
:return: The resulting AVA after the policy is applied :return: The resulting AVA after the policy is applied
""" """
policy.acs = self.acs
ava = policy.restrict(self, sp_entity_id, metadata) ava = policy.restrict(self, sp_entity_id, metadata)
self.update(ava) self.update(ava)
return ava return ava

View File

@@ -255,6 +255,13 @@ def to_local_name(acs, attr):
return attr.friendly_name return attr.friendly_name
def get_local_name(acs, attr, name_format):
for aconv in acs:
#print ac.format, name_format
if aconv.name_format == name_format:
return aconv._fro[attr]
def d_to_local_name(acs, attr): def d_to_local_name(acs, attr):
""" """
:param acs: List of AttributeConverter instances :param acs: List of AttributeConverter instances

View File

@@ -177,6 +177,7 @@ MAP = {
'edupersonaffiliation': EDUPERSON_OID+'1', 'edupersonaffiliation': EDUPERSON_OID+'1',
'eduPersonPrincipalName': EDUPERSON_OID+'6', 'eduPersonPrincipalName': EDUPERSON_OID+'6',
'edupersonprincipalname': EDUPERSON_OID+'6', 'edupersonprincipalname': EDUPERSON_OID+'6',
'eppn': EDUPERSON_OID+'6',
'localityName': X500ATTR_OID+'7', 'localityName': X500ATTR_OID+'7',
'owner': X500ATTR_OID+'32', 'owner': X500ATTR_OID+'32',
'norEduOrgUnitUniqueNumber': NOREDUPERSON_OID+'2', 'norEduOrgUnitUniqueNumber': NOREDUPERSON_OID+'2',

View File

@@ -308,6 +308,7 @@ class Server(Entity):
#if identity: #if identity:
_issuer = self._issuer(issuer) _issuer = self._issuer(issuer)
ast = Assertion(identity) ast = Assertion(identity)
ast.acs = self.config.getattr("attribute_converters", "idp")
if policy is None: if policy is None:
policy = Policy() policy = Policy()
try: try:

View File

@@ -1,2 +1,3 @@
#!/bin/sh
curl -O -G http://md.swamid.se/md/swamid-2.0.xml curl -O -G http://md.swamid.se/md/swamid-2.0.xml
mdexport.py -t local -o swamid2.md swamid-2.0.xml mdexport.py -t local -o swamid2.md swamid-2.0.xml