Deal with entity category (CoCo) that have more complex evaluation rules.
This commit is contained in:
@@ -356,9 +356,14 @@ class ACS(Service):
|
||||
return resp(self.environ, self.start_response)
|
||||
|
||||
try:
|
||||
conv_info = {'remote_addr': self.environ['REMOTE_ADDR'],
|
||||
'request_uri': self.environ['REQUEST_URI'],
|
||||
'entity_id': self.sp.config.entityid,
|
||||
'endpoints': self.sp.config.getattr('endpoints', 'sp')}
|
||||
|
||||
self.response = self.sp.parse_authn_request_response(
|
||||
response, binding, self.outstanding_queries,
|
||||
self.cache.outstanding_certs)
|
||||
self.cache.outstanding_certs, conv_info=conv_info)
|
||||
except UnknownPrincipal as excp:
|
||||
logger.error("UnknownPrincipal: %s", excp)
|
||||
resp = ServiceError("UnknownPrincipal: %s" % (excp,))
|
||||
|
@@ -82,10 +82,12 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
|
||||
try:
|
||||
friendly_name = attr["friendly_name"]
|
||||
except KeyError:
|
||||
friendly_name = get_local_name(acs, attr["name"], attr["name_format"])
|
||||
friendly_name = get_local_name(acs, attr["name"],
|
||||
attr["name_format"])
|
||||
|
||||
_fn = _match(friendly_name, ava)
|
||||
if not _fn: # In the unlikely case that someone has provided us with URIs as attribute names
|
||||
if not _fn: # In the unlikely case that someone has provided us with
|
||||
# URIs as attribute names
|
||||
_fn = _match(attr["name"], ava)
|
||||
|
||||
return _fn
|
||||
@@ -152,8 +154,8 @@ def filter_on_demands(ava, required=None, optional=None):
|
||||
for val in vals:
|
||||
if val not in ava[lava[attr]]:
|
||||
raise MissingValue(
|
||||
"Required attribute value missing: %s,%s" % (attr,
|
||||
val))
|
||||
"Required attribute value missing: %s,%s" % (attr,
|
||||
val))
|
||||
else:
|
||||
raise MissingValue("Required attribute missing: %s" % (attr,))
|
||||
|
||||
@@ -266,6 +268,11 @@ def restriction_from_attribute_spec(attributes):
|
||||
|
||||
def post_entity_categories(maps, **kwargs):
|
||||
restrictions = {}
|
||||
try:
|
||||
required = [d['friendly_name'].lower() for d in kwargs['required']]
|
||||
except KeyError:
|
||||
required = []
|
||||
|
||||
if kwargs["mds"]:
|
||||
try:
|
||||
ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"])
|
||||
@@ -275,11 +282,14 @@ def post_entity_categories(maps, **kwargs):
|
||||
restrictions[attr] = None
|
||||
else:
|
||||
for ec_map in maps:
|
||||
for key, val in ec_map.items():
|
||||
for key, (atlist, only_required) in ec_map.items():
|
||||
if key == "": # always released
|
||||
attrs = val
|
||||
attrs = atlist
|
||||
elif isinstance(key, tuple):
|
||||
attrs = val
|
||||
if only_required:
|
||||
attrs = [a for a in atlist if a in required]
|
||||
else:
|
||||
attrs = atlist
|
||||
for _key in key:
|
||||
try:
|
||||
assert _key in ecs
|
||||
@@ -287,7 +297,10 @@ def post_entity_categories(maps, **kwargs):
|
||||
attrs = []
|
||||
break
|
||||
elif key in ecs:
|
||||
attrs = val
|
||||
if only_required:
|
||||
attrs = [a for a in atlist if a in required]
|
||||
else:
|
||||
attrs = atlist
|
||||
else:
|
||||
attrs = []
|
||||
|
||||
@@ -332,10 +345,15 @@ class Policy(object):
|
||||
ecs = []
|
||||
for cat in items:
|
||||
_mod = importlib.import_module(
|
||||
"saml2.entity_category.%s" % cat)
|
||||
"saml2.entity_category.%s" % cat)
|
||||
_ec = {}
|
||||
for key, items in _mod.RELEASE.items():
|
||||
_ec[key] = [k.lower() for k in items]
|
||||
alist = [k.lower() for k in items]
|
||||
try:
|
||||
_only_required = _mod.ONLY_REQUIRED[key]
|
||||
except (AttributeError, KeyError):
|
||||
_only_required = False
|
||||
_ec[key] = (alist, _only_required)
|
||||
ecs.append(_ec)
|
||||
spec["entity_categories"] = ecs
|
||||
try:
|
||||
@@ -444,7 +462,7 @@ class Policy(object):
|
||||
pass
|
||||
return []
|
||||
|
||||
def get_entity_categories(self, sp_entity_id, mds):
|
||||
def get_entity_categories(self, sp_entity_id, mds, required):
|
||||
"""
|
||||
|
||||
:param sp_entity_id:
|
||||
@@ -452,7 +470,7 @@ class Policy(object):
|
||||
:return: A dictionary with restrictions
|
||||
"""
|
||||
|
||||
kwargs = {"mds": mds}
|
||||
kwargs = {"mds": mds, 'required': required}
|
||||
|
||||
return self.get("entity_categories", sp_entity_id, default={},
|
||||
post_func=post_entity_categories, **kwargs)
|
||||
@@ -483,19 +501,15 @@ class Policy(object):
|
||||
"""
|
||||
|
||||
_ava = None
|
||||
if required or optional:
|
||||
|
||||
_rest = self.get_entity_categories(sp_entity_id, mdstore, required)
|
||||
if _rest:
|
||||
_ava = filter_attribute_value_assertions(ava.copy(), _rest)
|
||||
elif required or optional:
|
||||
logger.debug("required: %s, optional: %s", required, optional)
|
||||
_ava = filter_on_attributes(
|
||||
ava.copy(), required, optional, self.acs,
|
||||
self.get_fail_on_missing_requested(sp_entity_id))
|
||||
|
||||
_rest = self.get_entity_categories(sp_entity_id, mdstore)
|
||||
if _rest:
|
||||
ava_ec = filter_attribute_value_assertions(ava.copy(), _rest)
|
||||
if _ava is None:
|
||||
_ava = ava_ec
|
||||
else:
|
||||
_ava.update(ava_ec)
|
||||
ava.copy(), required, optional, self.acs,
|
||||
self.get_fail_on_missing_requested(sp_entity_id))
|
||||
|
||||
_rest = self.get_attribute_restrictions(sp_entity_id)
|
||||
if _rest:
|
||||
@@ -537,9 +551,9 @@ class Policy(object):
|
||||
# How long might depend on who's getting it
|
||||
not_on_or_after=self.not_on_or_after(sp_entity_id),
|
||||
audience_restriction=[factory(
|
||||
saml.AudienceRestriction,
|
||||
audience=[factory(saml.Audience,
|
||||
text=sp_entity_id)])])
|
||||
saml.AudienceRestriction,
|
||||
audience=[factory(saml.Audience,
|
||||
text=sp_entity_id)])])
|
||||
|
||||
def get_sign(self, sp_entity_id):
|
||||
"""
|
||||
@@ -569,7 +583,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
|
||||
return factory(saml.AuthnContext,
|
||||
authn_context_class_ref=cntx_class,
|
||||
authenticating_authority=factory(
|
||||
saml.AuthenticatingAuthority, text=authn_auth))
|
||||
saml.AuthenticatingAuthority, text=authn_auth))
|
||||
else:
|
||||
return factory(saml.AuthnContext,
|
||||
authn_context_class_ref=cntx_class)
|
||||
@@ -585,7 +599,7 @@ def _authn_context_decl(decl, authn_auth=None):
|
||||
return factory(saml.AuthnContext,
|
||||
authn_context_decl=decl,
|
||||
authenticating_authority=factory(
|
||||
saml.AuthenticatingAuthority, text=authn_auth))
|
||||
saml.AuthenticatingAuthority, text=authn_auth))
|
||||
|
||||
|
||||
def _authn_context_decl_ref(decl_ref, authn_auth=None):
|
||||
@@ -598,7 +612,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
|
||||
return factory(saml.AuthnContext,
|
||||
authn_context_decl_ref=decl_ref,
|
||||
authenticating_authority=factory(
|
||||
saml.AuthenticatingAuthority, text=authn_auth))
|
||||
saml.AuthenticatingAuthority, text=authn_auth))
|
||||
|
||||
|
||||
def authn_statement(authn_class=None, authn_auth=None,
|
||||
@@ -624,29 +638,29 @@ def authn_statement(authn_class=None, authn_auth=None,
|
||||
|
||||
if authn_class:
|
||||
res = factory(
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid(),
|
||||
authn_context=_authn_context_class_ref(
|
||||
authn_class, authn_auth))
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid(),
|
||||
authn_context=_authn_context_class_ref(
|
||||
authn_class, authn_auth))
|
||||
elif authn_decl:
|
||||
res = factory(
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid(),
|
||||
authn_context=_authn_context_decl(authn_decl, authn_auth))
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid(),
|
||||
authn_context=_authn_context_decl(authn_decl, authn_auth))
|
||||
elif authn_decl_ref:
|
||||
res = factory(
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid(),
|
||||
authn_context=_authn_context_decl_ref(authn_decl_ref,
|
||||
authn_auth))
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid(),
|
||||
authn_context=_authn_context_decl_ref(authn_decl_ref,
|
||||
authn_auth))
|
||||
else:
|
||||
res = factory(
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid())
|
||||
saml.AuthnStatement,
|
||||
authn_instant=_instant,
|
||||
session_index=sid())
|
||||
|
||||
if subject_locality:
|
||||
res.subject_locality = saml.SubjectLocality(text=subject_locality)
|
||||
@@ -688,7 +702,8 @@ def do_subject(policy, sp_entity_id, name_id, **farg):
|
||||
specs = farg['subject_confirmation']
|
||||
|
||||
if isinstance(specs, list):
|
||||
res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in specs]
|
||||
res = [do_subject_confirmation(policy, sp_entity_id, **s) for s in
|
||||
specs]
|
||||
else:
|
||||
res = [do_subject_confirmation(policy, sp_entity_id, **specs)]
|
||||
|
||||
@@ -736,7 +751,7 @@ class Assertion(dict):
|
||||
_name_format = NAME_FORMAT_URI
|
||||
|
||||
attr_statement = saml.AttributeStatement(attribute=from_local(
|
||||
attrconvs, self, _name_format))
|
||||
attrconvs, self, _name_format))
|
||||
|
||||
if encrypt == "attributes":
|
||||
for attr in attr_statement.attribute:
|
||||
|
@@ -26,7 +26,7 @@ MAP = {
|
||||
EDUPERSON_OID+'6': 'eduPersonPrimaryAffiliation',
|
||||
EDUPERSON_OID+'7': 'eduPersonPrimaryOrgUnitDN',
|
||||
EDUPERSON_OID+'8': 'eduPersonPrincipalName',
|
||||
EDUPERSON_OID+'9': 'eduPersonPrincipalNamePrior',
|
||||
EDUPERSON_OID+'9': 'eduPersonPrincipalName',
|
||||
EDUPERSON_OID+'10': 'eduPersonScopedAffiliation',
|
||||
EDUPERSON_OID+'11': 'eduPersonTargetedID',
|
||||
EDUPERSON_OID+'12': 'eduPersonAssurance',
|
||||
|
@@ -12,3 +12,4 @@ RELEASE = {
|
||||
"schacHomeOrganization"]
|
||||
}
|
||||
|
||||
ONLY_REQUIRED = {COCO: True}
|
||||
|
Reference in New Issue
Block a user