diff --git a/src/saml2/authn_context/__init__.py b/src/saml2/authn_context/__init__.py index 20316eb..f2c01b9 100644 --- a/src/saml2/authn_context/__init__.py +++ b/src/saml2/authn_context/__init__.py @@ -1,4 +1,5 @@ from saml2.saml import AuthnContext, AuthnContextClassRef +from saml2.samlp import RequestedAuthnContext __author__ = 'rolandh' @@ -150,8 +151,8 @@ class AuthnBroker(object): Given the authentication context find zero or more places where the user could be sent next. Ordered according to security level. - :param req_authn_context: The requested context as an AuthnContext - instance + :param req_authn_context: The requested context as an + RequestedAuthnContext instance :return: An URL """ @@ -164,17 +165,13 @@ class AuthnBroker(object): _cmp = "minimum" return self._pick_by_class_ref( req_authn_context.authn_context_class_ref.text, _cmp) - elif req_authn_context.authn_context_decl: - _decl = req_authn_context.authn_context_decl - key = _decl.c_namespace - _methods = [] - for _ref in self.db["key"][key]: - _dic = self.db["info"][_ref] - if self.match(_decl, _dic["decl"]): - _val = (_dic["method"], _ref) - if _val not in _methods: - _methods.append(_val) - return _methods + elif req_authn_context.authn_context_decl_ref: + if req_authn_context.comparison: + _cmp = req_authn_context.comparison + else: + _cmp = "minimum" + return self._pick_by_class_ref( + req_authn_context.authn_context_decl_ref, _cmp) def match(self, requested, provided): if requested == provided: @@ -206,4 +203,10 @@ def authn_context_decl_from_extension_elements(extelems): def authn_context_class_ref(ref): - return AuthnContext(authn_context_class_ref=AuthnContextClassRef(text=ref)) \ No newline at end of file + return AuthnContext(authn_context_class_ref=AuthnContextClassRef(text=ref)) + + +def requested_authn_context(class_ref, comparison="minimum"): + return RequestedAuthnContext( + authn_context_class_ref=AuthnContextClassRef(text=class_ref), + comparison=comparison) \ No newline at end of file diff --git a/tests/test_77_authn_context.py b/tests/test_77_authn_context.py index 97cce92..c98c3bc 100644 --- a/tests/test_77_authn_context.py +++ b/tests/test_77_authn_context.py @@ -15,7 +15,11 @@ ex1 = """