diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index 4320e62..7d13ce6 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -787,7 +787,7 @@ def extension_element_to_element(extension_element, translation_functions, element it is. Or rather which module it belongs to. :param extension_element: The extension element - :prama translation_functions: A dictionary which klass identifiers + :param translation_functions: A dictionary with class identifiers as keys and string-to-element translations functions as values :param namespace: The namespace of the translation functions. :return: An element instance or None diff --git a/src/saml2/authn_context/__init__.py b/src/saml2/authn_context/__init__.py index b8f11f6..a8432ed 100644 --- a/src/saml2/authn_context/__init__.py +++ b/src/saml2/authn_context/__init__.py @@ -1,5 +1,7 @@ __author__ = 'rolandh' +from saml2 import extension_elements_to_elements + INTERNETPROTOCOLPASSWORD = \ 'urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword' MOBILETWOFACTORCONTRACT = \ @@ -52,7 +54,8 @@ class Authn(object): authentication context is defined find out where to send the user next. :param endpoint: The service endpoint URL - :param authn_context: An AuthnContext instance + :param req_authn_context: The requested context as an AuthnContext + instance :return: An URL """ @@ -66,8 +69,8 @@ class Authn(object): return _endpspec[req_authn_context.authn_context_class_ref.text] elif req_authn_context.authn_context_decl: key = req_authn_context.authn_context_decl.c_namespace - for spec, target in _endpspec[key]: - if self.match(req_authn_context, spec): + for acd, target in _endpspec[key]: + if self.match(req_authn_context.authn_context_decl, acd): return target def match(self, requested, provided): @@ -84,4 +87,12 @@ def authn_context_factory(text): if inst: return inst - return None \ No newline at end of file + return None + +def authn_context_decl_from_extension_elements(extelems): + res = extension_elements_to_elements(extelems, [ippword, mobiletwofactor, + ppt, pword, sslcert]) + try: + return res[0] + except IndexError: + return None \ No newline at end of file diff --git a/tests/test_77_authn_context.py b/tests/test_77_authn_context.py index faefe5c..db16323 100644 --- a/tests/test_77_authn_context.py +++ b/tests/test_77_authn_context.py @@ -1,3 +1,7 @@ +from saml2.saml import AuthnContext +from saml2.saml import authn_context_from_string +from saml2.saml import AuthnContextClassRef + __author__ = 'rolandh' ex1 = """ """ -from saml2.authn_context import pword +from saml2.authn_context import pword, PASSWORDPROTECTEDTRANSPORT +from saml2.authn_context import Authn +from saml2.authn_context import authn_context_decl_from_extension_elements from saml2.authn_context import authn_context_factory -def test_passwd(): - length = pword.Length(min="4") - restricted_password = pword.RestrictedPassword(length=length) - authenticator = pword.Authenticator(restricted_password=restricted_password) - authn_method = pword.AuthnMethod(authenticator=authenticator) - inst = pword.AuthenticationContextDeclaration(authn_method=authn_method) +length = pword.Length(min="4") +restricted_password = pword.RestrictedPassword(length=length) +authenticator = pword.Authenticator(restricted_password=restricted_password) +authn_method = pword.AuthnMethod(authenticator=authenticator) +ACD = pword.AuthenticationContextDeclaration(authn_method=authn_method) +AUTHNCTXT = AuthnContext(authn_context_decl=ACD) + + +def test_passwd(): + inst = ACD inst2 = pword.authentication_context_declaration_from_string(ex1) assert inst == inst2 @@ -32,5 +42,38 @@ def test_factory(): assert inst_pw == inst + +def test_authn_decl_in_authn_context(): + authnctxt = AuthnContext(authn_context_decl=ACD) + + acs = authn_context_from_string("%s" % authnctxt) + if acs.extension_elements: + cacd = authn_context_decl_from_extension_elements( + acs.extension_elements) + if cacd: + acs.authn_context_decl = cacd + + assert acs.authn_context_decl == ACD + + +def test_authn_1(): + accr = AuthnContextClassRef(text=PASSWORDPROTECTEDTRANSPORT) + ac = AuthnContext(authn_context_class_ref=accr) + authn = Authn() + target = "https://example.org/login" + endpoint = "https://example.com/sso/redirect" + authn.add(endpoint, ac, target) + + assert target == authn.pick(endpoint, ac) + + +def test_authn_2(): + authn = Authn() + target = "https://example.org/login" + endpoint = "https://example.com/sso/redirect" + authn.add(endpoint, AUTHNCTXT, target) + + assert target == authn.pick(endpoint, AUTHNCTXT) + if __name__ == "__main__": - test_factory() \ No newline at end of file + test_authn_2() \ No newline at end of file