Fixed authn query request-response

This commit is contained in:
Roland Hedberg
2013-01-14 13:55:30 +01:00
parent 85b3fc307c
commit 3dff9e5245
7 changed files with 203 additions and 8 deletions

View File

@@ -492,13 +492,20 @@ class Base(Entity):
binding=BINDING_SOAP): binding=BINDING_SOAP):
""" Verify that the response is OK """ Verify that the response is OK
""" """
kwargs = {"entity_id": self.config.entityid,
"attribute_converters": self.config.attribute_converters}
return self._parse_response(response, AuthzResponse, "", binding) return self._parse_response(response, AuthzResponse, "", binding,
**kwargs)
def parse_authn_query_response(self, response, binding=BINDING_SOAP): def parse_authn_query_response(self, response, binding=BINDING_SOAP):
""" Verify that the response is OK """ Verify that the response is OK
""" """
return self._parse_response(response, AuthnQueryResponse, "", binding) kwargs = {"entity_id": self.config.entityid,
"attribute_converters": self.config.attribute_converters}
return self._parse_response(response, AuthnQueryResponse, "", binding,
**kwargs)
def parse_assertion_id_request_response(self, response, binding): def parse_assertion_id_request_response(self, response, binding):
""" Verify that the response is OK """ Verify that the response is OK

View File

@@ -157,6 +157,18 @@ class AuthnRequest(Request):
return to_local(self.attribute_converters, self.message) return to_local(self.attribute_converters, self.message)
class AuthnQuery(Request):
msgtype = "authn_query"
def __init__(self, sec_context, receiver_addrs, attribute_converters,
timeslack=0):
Request.__init__(self, sec_context, receiver_addrs,
attribute_converters, timeslack)
self.signature_check = self.sec.correctly_signed_authn_query
def attributes(self):
return to_local(self.attribute_converters, self.message)
class AssertionIDRequest(Request): class AssertionIDRequest(Request):
msgtype = "assertion_id_request" msgtype = "assertion_id_request"
def __init__(self, sec_context, receiver_addrs, attribute_converters, def __init__(self, sec_context, receiver_addrs, attribute_converters,

View File

@@ -670,6 +670,8 @@ class AuthnQueryResponse(AuthnResponse):
self.assertion = None self.assertion = None
self.context = "AuthnQueryResponse" self.context = "AuthnQueryResponse"
def condition_ok(self, lax=False): # Should I care about conditions ?
return True
class AttributeResponse(AuthnResponse): class AttributeResponse(AuthnResponse):
msgtype = "attribute_response" msgtype = "attribute_response"

View File

@@ -23,9 +23,7 @@ import logging
import shelve import shelve
import sys import sys
import memcache import memcache
from saml2.samlp import AuthzDecisionQuery
from saml2.samlp import NameIDMappingResponse from saml2.samlp import NameIDMappingResponse
from saml2.samlp import AuthnQuery
from saml2.entity import Entity from saml2.entity import Entity
from saml2 import saml from saml2 import saml
@@ -36,6 +34,8 @@ from saml2.request import AuthnRequest
from saml2.request import AssertionIDRequest from saml2.request import AssertionIDRequest
from saml2.request import AttributeQuery from saml2.request import AttributeQuery
from saml2.request import NameIDMappingRequest from saml2.request import NameIDMappingRequest
from saml2.request import AuthzDecisionQuery
from saml2.request import AuthnQuery
from saml2.s_utils import sid from saml2.s_utils import sid
from saml2.s_utils import MissingValue from saml2.s_utils import MissingValue
@@ -53,6 +53,9 @@ logger = logging.getLogger(__name__)
class UnknownVO(Exception): class UnknownVO(Exception):
pass pass
def context_match(cfilter, cntx):
return True
class Identifier(object): class Identifier(object):
""" A class that handles identifiers of objects """ """ A class that handles identifiers of objects """
def __init__(self, db, voconf=None): def __init__(self, db, voconf=None):
@@ -340,6 +343,34 @@ class Server(Entity):
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def store_assertion(self, assertion, to_sign):
self.assertion[assertion.id] = (assertion, to_sign)
def get_assertion(self, id):
return self.assertion[id]
def store_authn_statement(self, authn_statement, name_id):
try:
self.authn[name_id.text].append(authn_statement)
except:
self.authn[name_id.text] = [authn_statement]
def get_authn_statements(self, subject, session_index=None,
requested_context=None):
result = []
for statement in self.authn[subject.name_id.text]:
if session_index:
if statement.session_index != session_index:
continue
if requested_context:
if not context_match(requested_context, statement.authn_context):
continue
result.append(statement)
return result
# ------------------------------------------------------------------------
def _authn_response(self, in_response_to, consumer_url, def _authn_response(self, in_response_to, consumer_url,
sp_entity_id, identity=None, name_id=None, sp_entity_id, identity=None, name_id=None,
status=None, authn=None, status=None, authn=None,
@@ -384,12 +415,14 @@ class Server(Entity):
policy, issuer=_issuer, policy, issuer=_issuer,
authn_class=authn_class, authn_class=authn_class,
authn_auth=authn_authn) authn_auth=authn_authn)
self.store_authn_statement(assertion.authn_statement, name_id)
elif authn_decl: elif authn_decl:
assertion = ast.construct(sp_entity_id, in_response_to, assertion = ast.construct(sp_entity_id, in_response_to,
consumer_url, name_id, consumer_url, name_id,
self.config.attribute_converters, self.config.attribute_converters,
policy, issuer=_issuer, policy, issuer=_issuer,
authn_decl=authn_decl) authn_decl=authn_decl)
self.store_authn_statement(assertion.authn_statement, name_id)
else: else:
assertion = ast.construct(sp_entity_id, in_response_to, assertion = ast.construct(sp_entity_id, in_response_to,
consumer_url, name_id, consumer_url, name_id,
@@ -411,7 +444,7 @@ class Server(Entity):
args["assertion"] = assertion args["assertion"] = assertion
self.assertion[assertion.id] = (assertion, to_sign) self.store_assertion(assertion, to_sign)
return self._response(in_response_to, consumer_url, status, issuer, return self._response(in_response_to, consumer_url, status, issuer,
sign_response, to_sign, **args) sign_response, to_sign, **args)
@@ -578,7 +611,7 @@ class Server(Entity):
for aid in assertion_id: for aid in assertion_id:
try: try:
(assertion, to_sign) = self.assertion[aid] (assertion, to_sign) = self.get_assertion(aid)
to_sign.extend(to_sign) to_sign.extend(to_sign)
try: try:
args["assertion"].append(assertion) args["assertion"].append(assertion)
@@ -620,3 +653,29 @@ class Server(Entity):
logger.info("Message: %s" % _resp) logger.info("Message: %s" % _resp)
return _resp return _resp
def create_authn_query_response(self, subject, session_index=None,
requested_context=None, in_response_to=None,
issuer=None, sign_response=False,
status=None):
"""
A successful <Response> will contain one or more assertions containing
authentication statements.
:return:
"""
margs = self.message_args()
asserts = []
for statement in self.get_authn_statements(subject, session_index,
requested_context):
asserts.append(saml.Assertion(authn_statement=statement,
subject=subject, **margs))
if asserts:
args = {"assertion": asserts}
else:
args = {}
return self._response(in_response_to, "", status, issuer,
sign_response, to_sign=[], **args)

View File

@@ -765,6 +765,11 @@ class SecurityContext(object):
return self.correctly_signed_message(decoded_xml, "authn_request", return self.correctly_signed_message(decoded_xml, "authn_request",
must, origdoc) must, origdoc)
def correctly_signed_authn_query(self, decoded_xml, must=False,
origdoc=None):
return self.correctly_signed_message(decoded_xml, "authn_query",
must, origdoc)
def correctly_signed_logout_request(self, decoded_xml, must=False, def correctly_signed_logout_request(self, decoded_xml, must=False,
origdoc=None): origdoc=None):
return self.correctly_signed_message(decoded_xml, "logout_request", return self.correctly_signed_message(decoded_xml, "logout_request",

View File

@@ -92,6 +92,15 @@ def parse_soap_enveloped_saml_assertion_id_response(text):
'{%s}AssertionIDResponse' % SAMLP_NAMESPACE] '{%s}AssertionIDResponse' % SAMLP_NAMESPACE]
return parse_soap_enveloped_saml_thingy(text, tags) return parse_soap_enveloped_saml_thingy(text, tags)
def parse_soap_enveloped_saml_authn_query(text):
expected_tag = '{%s}AuthnQuery' % SAMLP_NAMESPACE
return parse_soap_enveloped_saml_thingy(text, [expected_tag])
def parse_soap_enveloped_saml_authn_query_response(text):
tags = ['{%s}Response' % SAMLP_NAMESPACE]
return parse_soap_enveloped_saml_thingy(text, tags)
#def parse_soap_enveloped_saml_logout_response(text): #def parse_soap_enveloped_saml_logout_response(text):
# expected_tag = '{%s}LogoutResponse' % SAMLP_NAMESPACE # expected_tag = '{%s}LogoutResponse' % SAMLP_NAMESPACE
# return parse_soap_enveloped_saml_thingy(text, [expected_tag]) # return parse_soap_enveloped_saml_thingy(text, [expected_tag])

View File

@@ -1,7 +1,9 @@
from urlparse import urlparse, parse_qs
from saml2 import BINDING_SOAP, BINDING_HTTP_POST
__author__ = 'rolandh' __author__ = 'rolandh'
from saml2.samlp import RequestedAuthnContext from saml2.samlp import RequestedAuthnContext, AuthnRequest, NameIDPolicy
from saml2.samlp import AuthnQuery from saml2.samlp import AuthnQuery
from saml2.client import Saml2Client from saml2.client import Saml2Client
from saml2.saml import AUTHN_PASSWORD from saml2.saml import AUTHN_PASSWORD
@@ -11,6 +13,25 @@ from saml2.saml import NameID
from saml2.saml import NAMEID_FORMAT_TRANSIENT from saml2.saml import NAMEID_FORMAT_TRANSIENT
from saml2.server import Server from saml2.server import Server
TAG1 = "name=\"SAMLRequest\" value="
def get_msg(hinfo, binding):
if binding == BINDING_SOAP:
xmlstr = hinfo["data"]
elif binding == BINDING_HTTP_POST:
_inp = hinfo["data"][3]
i = _inp.find(TAG1)
i += len(TAG1) + 1
j = _inp.find('"', i)
xmlstr = _inp[i:j]
else: # BINDING_HTTP_REDIRECT
parts = urlparse(hinfo["headers"][0][1])
xmlstr = parse_qs(parts.query)["SAMLRequest"][0]
return xmlstr
# ------------------------------------------------------------------------
def test_basic(): def test_basic():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf") idp = Server(config_file="idp_all_conf")
@@ -29,3 +50,83 @@ def test_basic():
print aq print aq
assert isinstance(aq, AuthnQuery) assert isinstance(aq, AuthnQuery)
def test_flow():
sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf")
relay_state = "FOO"
# -- dummy request ---
orig_req = AuthnRequest(issuer=sp._issuer(),
name_id_policy=NameIDPolicy(allow_create="true",
format=NAMEID_FORMAT_TRANSIENT))
# == Create an AuthnRequest response
name_id = idp.ident.transient_nameid(sp.config.entityid, "id12")
binding, destination = idp.pick_binding("assertion_consumer_service",
entity_id=sp.config.entityid)
resp = idp.create_authn_response({"eduPersonEntitlement": "Short stop",
"surName": "Jeter",
"givenName": "Derek",
"mail": "derek.jeter@nyy.mlb.com",
"title": "The man"},
"id-123456789",
destination,
sp.config.entityid,
name_id=name_id,
authn=(AUTHN_PASSWORD,
"http://www.example.com/login"))
hinfo = idp.apply_binding(binding, "%s" % resp, destination, relay_state)
# ------- @SP ----------
xmlstr = get_msg(hinfo, binding)
aresp = sp.parse_authn_request_response(xmlstr, binding,
{resp.in_response_to :"/"})
binding, destination = sp.pick_binding("authn_query_service",
entity_id=idp.config.entityid)
authn_context = [RequestedAuthnContext(
authn_context_class_ref=AuthnContextClassRef(
text=AUTHN_PASSWORD))]
subject = aresp.assertion.subject
aq = sp.create_authn_query(subject, destination, authn_context)
print aq
assert isinstance(aq, AuthnQuery)
binding = BINDING_SOAP
hinfo = sp.apply_binding(binding, "%s" % aq, destination, "state2")
# -------- @IDP ----------
xmlstr = get_msg(hinfo, binding)
pm = idp.parse_authn_query(xmlstr, binding)
msg = pm.message
assert msg.id == aq.id
p_res = idp.create_authn_query_response(msg.subject, msg.session_index,
msg.requested_authn_context)
print p_res
hinfo = idp.apply_binding(binding, "%s" % p_res, "", "state2", "SAMLResponse")
# ------- @SP ----------
xmlstr = get_msg(hinfo, binding)
final = sp.parse_authn_query_response(xmlstr, binding)
print final
assert final.response.id == p_res.id