Deal with entity category (CoCo) that have more complex evaluation rules.

This commit is contained in:
Roland Hedberg
2016-05-16 20:48:56 +02:00
parent 9c04dc7ebb
commit 9ef92af7d8
4 changed files with 71 additions and 50 deletions

View File

@@ -356,9 +356,14 @@ class ACS(Service):
return resp(self.environ, self.start_response)
try:
conv_info = {'remote_addr': self.environ['REMOTE_ADDR'],
'request_uri': self.environ['REQUEST_URI'],
'entity_id': self.sp.config.entityid,
'endpoints': self.sp.config.getattr('endpoints', 'sp')}
self.response = self.sp.parse_authn_request_response(
response, binding, self.outstanding_queries,
self.cache.outstanding_certs)
self.cache.outstanding_certs, conv_info=conv_info)
except UnknownPrincipal as excp:
logger.error("UnknownPrincipal: %s", excp)
resp = ServiceError("UnknownPrincipal: %s" % (excp,))

View File

@@ -82,10 +82,12 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
try:
friendly_name = attr["friendly_name"]
except KeyError:
friendly_name = get_local_name(acs, attr["name"], attr["name_format"])
friendly_name = get_local_name(acs, attr["name"],
attr["name_format"])
_fn = _match(friendly_name, ava)
if not _fn: # In the unlikely case that someone has provided us with URIs as attribute names
if not _fn: # In the unlikely case that someone has provided us with
# URIs as attribute names
_fn = _match(attr["name"], ava)
return _fn
@@ -152,8 +154,8 @@ def filter_on_demands(ava, required=None, optional=None):
for val in vals:
if val not in ava[lava[attr]]:
raise MissingValue(
"Required attribute value missing: %s,%s" % (attr,
val))
"Required attribute value missing: %s,%s" % (attr,
val))
else:
raise MissingValue("Required attribute missing: %s" % (attr,))
@@ -266,6 +268,11 @@ def restriction_from_attribute_spec(attributes):
def post_entity_categories(maps, **kwargs):
restrictions = {}
try:
required = [d['friendly_name'].lower() for d in kwargs['required']]
except KeyError:
required = []
if kwargs["mds"]:
try:
ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"])
@@ -275,11 +282,14 @@ def post_entity_categories(maps, **kwargs):
restrictions[attr] = None
else:
for ec_map in maps:
for key, val in ec_map.items():
for key, (atlist, only_required) in ec_map.items():
if key == "": # always released
attrs = val
attrs = atlist
elif isinstance(key, tuple):
attrs = val
if only_required:
attrs = [a for a in atlist if a in required]
else:
attrs = atlist
for _key in key:
try:
assert _key in ecs
@@ -287,7 +297,10 @@ def post_entity_categories(maps, **kwargs):
attrs = []
break
elif key in ecs:
attrs = val
if only_required:
attrs = [a for a in atlist if a in required]
else:
attrs = atlist
else:
attrs = []
@@ -332,10 +345,15 @@ class Policy(object):
ecs = []
for cat in items:
_mod = importlib.import_module(
"saml2.entity_category.%s" % cat)
"saml2.entity_category.%s" % cat)
_ec = {}
for key, items in _mod.RELEASE.items():
_ec[key] = [k.lower() for k in items]
alist = [k.lower() for k in items]
try:
_only_required = _mod.ONLY_REQUIRED[key]
except (AttributeError, KeyError):
_only_required = False
_ec[key] = (alist, _only_required)
ecs.append(_ec)
spec["entity_categories"] = ecs
try:
@@ -444,7 +462,7 @@ class Policy(object):
pass
return []
def get_entity_categories(self, sp_entity_id, mds):
def get_entity_categories(self, sp_entity_id, mds, required):
"""
:param sp_entity_id:
@@ -452,7 +470,7 @@ class Policy(object):
:return: A dictionary with restrictions
"""
kwargs = {"mds": mds}
kwargs = {"mds": mds, 'required': required}
return self.get("entity_categories", sp_entity_id, default={},
post_func=post_entity_categories, **kwargs)
@@ -483,19 +501,15 @@ class Policy(object):
"""
_ava = None
if required or optional:
_rest = self.get_entity_categories(sp_entity_id, mdstore, required)
if _rest:
_ava = filter_attribute_value_assertions(ava.copy(), _rest)
elif required or optional:
logger.debug("required: %s, optional: %s", required, optional)
_ava = filter_on_attributes(
ava.copy(), required, optional, self.acs,
self.get_fail_on_missing_requested(sp_entity_id))
_rest = self.get_entity_categories(sp_entity_id, mdstore)
if _rest:
ava_ec = filter_attribute_value_assertions(ava.copy(), _rest)
if _ava is None:
_ava = ava_ec
else:
_ava.update(ava_ec)
ava.copy(), required, optional, self.acs,
self.get_fail_on_missing_requested(sp_entity_id))
_rest = self.get_attribute_restrictions(sp_entity_id)
if _rest:
@@ -537,9 +551,9 @@ class Policy(object):
# How long might depend on who's getting it
not_on_or_after=self.not_on_or_after(sp_entity_id),
audience_restriction=[factory(
saml.AudienceRestriction,
audience=[factory(saml.Audience,
text=sp_entity_id)])])
saml.AudienceRestriction,
audience=[factory(saml.Audience,
text=sp_entity_id)])])
def get_sign(self, sp_entity_id):
"""
@@ -569,7 +583,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
return factory(saml.AuthnContext,
authn_context_class_ref=cntx_class,
authenticating_authority=factory(
saml.AuthenticatingAuthority, text=authn_auth))
saml.AuthenticatingAuthority, text=authn_auth))
else:
return factory(saml.AuthnContext,
authn_context_class_ref=cntx_class)
@@ -585,7 +599,7 @@ def _authn_context_decl(decl, authn_auth=None):
return factory(saml.AuthnContext,
authn_context_decl=decl,
authenticating_authority=factory(
saml.AuthenticatingAuthority, text=authn_auth))
saml.AuthenticatingAuthority, text=authn_auth))
def _authn_context_decl_ref(decl_ref, authn_auth=None):
@@ -598,7 +612,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
return factory(saml.AuthnContext,
authn_context_decl_ref=decl_ref,
authenticating_authority=factory(
saml.AuthenticatingAuthority, text=authn_auth))
saml.AuthenticatingAuthority, text=authn_auth))
def authn_statement(authn_class=None, authn_auth=None,
@@ -624,29 +638,29 @@ def authn_statement(authn_class=None, authn_auth=None,
if authn_class:
res = factory(
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid(),
authn_context=_authn_context_class_ref(
authn_class, authn_auth))
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid(),
authn_context=_authn_context_class_ref(
authn_class, authn_auth))
elif authn_decl:
res = factory(
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid(),
authn_context=_authn_context_decl(authn_decl, authn_auth))
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid(),
authn_context=_authn_context_decl(authn_decl, authn_auth))
elif authn_decl_ref:
res = factory(
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid(),
authn_context=_authn_context_decl_ref(authn_decl_ref,
authn_auth))
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid(),
authn_context=_authn_context_decl_ref(authn_decl_ref,
authn_auth))
else:
res = factory(
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid())
saml.AuthnStatement,
authn_instant=_instant,
session_index=sid())
if subject_locality:
res.subject_locality = saml.SubjectLocality(text=subject_locality)
@@ -688,7 +702,8 @@ def do_subject(policy, sp_entity_id, name_id, **farg):
specs = farg['subject_confirmation']
if isinstance(specs, list):
res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in specs]
res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in
specs]
else:
res = [do_subject_confirmation(policy, sp_entity_id, **specs)]
@@ -736,7 +751,7 @@ class Assertion(dict):
_name_format = NAME_FORMAT_URI
attr_statement = saml.AttributeStatement(attribute=from_local(
attrconvs, self, _name_format))
attrconvs, self, _name_format))
if encrypt == "attributes":
for attr in attr_statement.attribute:

View File

@@ -26,7 +26,7 @@ MAP = {
EDUPERSON_OID+'6': 'eduPersonPrimaryAffiliation',
EDUPERSON_OID+'7': 'eduPersonPrimaryOrgUnitDN',
EDUPERSON_OID+'8': 'eduPersonPrincipalName',
EDUPERSON_OID+'9': 'eduPersonPrincipalNamePrior',
EDUPERSON_OID+'9': 'eduPersonPrincipalName',
EDUPERSON_OID+'10': 'eduPersonScopedAffiliation',
EDUPERSON_OID+'11': 'eduPersonTargetedID',
EDUPERSON_OID+'12': 'eduPersonAssurance',

View File

@@ -12,3 +12,4 @@ RELEASE = {
"schacHomeOrganization"]
}
ONLY_REQUIRED = {COCO: True}