Made AttributeResponse subclass of AuthResponse. Specified different processes depending on asynchronous or synchronous (SOAP) operation. Use the original XML document when checking signature

This commit is contained in:
Roland Hedberg
2011-04-27 14:15:51 +02:00
parent 937cabd050
commit 5932692726

View File

@@ -1,4 +1,4 @@
#!/usr/bin/python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Umeå University
@@ -80,7 +80,8 @@ def attribute_response(conf, return_addr, log=None, timeslack=0,
timeslack = 0
return AttributeResponse(sec, conf.attribute_converters, conf.entityid,
return_addr, log, timeslack, debug)
return_addr, log, timeslack, debug,
asynchop=False)
class StatusResponse(object):
def __init__(self, sec_context, return_addr=None, log=None, timeslack=0,
@@ -136,7 +137,7 @@ class StatusResponse(object):
self.response = self.sec.check_signature(instance)
return self._postamble()
def _loads(self, xmldata, decode=True):
def _loads(self, xmldata, decode=True, origxml=None):
if decode:
decoded_xml = base64.b64decode(xmldata)
else:
@@ -146,9 +147,12 @@ class StatusResponse(object):
self.xmlstr = decoded_xml[:]
if self.debug:
self.log.info("xmlstr: %s" % (self.xmlstr,))
fil = open("response.xml", "w")
fil.write(self.xmlstr)
fil.close()
try:
self.response = self.signature_check(decoded_xml)
self.response = self.signature_check(decoded_xml, origdoc=origxml)
except TypeError:
raise
except SignatureError:
@@ -207,8 +211,8 @@ class StatusResponse(object):
assert self.status_ok()
return self
def loads(self, xmldata, decode=True):
return self._loads(xmldata, decode)
def loads(self, xmldata, decode=True, origxml=None):
return self._loads(xmldata, decode, origxml)
def verify(self):
try:
@@ -231,41 +235,56 @@ class LogoutResponse(StatusResponse):
debug)
self.signature_check = self.sec.correctly_signed_logout_response
class AttributeResponse(StatusResponse):
def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, log=None, timeslack=0, debug=0):
StatusResponse.__init__(self, sec_context, return_addr, log, timeslack,
debug)
self.entity_id = entity_id
self.attribute_converters = attribute_converters
self.assertion = None
def get_identity(self):
# The assertion can contain zero or one attributeStatements
if not self.assertion.attribute_statement:
self.log.error("Missing Attribute Statement")
ava = {}
else:
assert len(self.assertion.attribute_statement) == 1
if self.debug:
self.log.info("Attribute Statement: %s" % (
self.assertion.attribute_statement[0],))
for aconv in self.attribute_converters:
self.log.info(
"Converts name format: %s" % (aconv.name_format,))
ava = to_local(self.attribute_converters,
self.assertion.attribute_statement[0])
return ava
#class AttributeResponse(StatusResponse):
# def __init__(self, sec_context, attribute_converters, entity_id,
# return_addr=None, log=None, timeslack=0, debug=0):
# StatusResponse.__init__(self, sec_context, return_addr, log, timeslack,
# debug)
# self.entity_id = entity_id
# self.attribute_converters = attribute_converters
# self.assertion = None
#
# def get_identity(self):
# # The assertion can contain zero or one attributeStatements
# if not self.assertion.attribute_statement:
# self.log.error("Missing Attribute Statement")
# ava = {}
# else:
# assert len(self.assertion.attribute_statement) == 1
#
# if self.debug:
# self.log.info("Attribute Statement: %s" % (
# self.assertion.attribute_statement[0],))
# for aconv in self.attribute_converters:
# self.log.info(
# "Converts name format: %s" % (aconv.name_format,))
#
# ava = to_local(self.attribute_converters,
# self.assertion.attribute_statement[0])
# return ava
#
# def session_info(self):
# """ Returns a predefined set of information gleened from the
# response.
# :returns: Dictionary with information
# """
# if self.session_not_on_or_after > 0:
# nooa = self.session_not_on_or_after
# else:
# nooa = self.not_on_or_after
#
# return { "ava": self.ava, "name_id": self.name_id,
# "came_from": self.came_from, "issuer": self.issuer(),
# "not_on_or_after": nooa,
# "authn_info": self.authn_info() }
class AuthnResponse(StatusResponse):
""" This is where all the profile complience is checked.
This one does saml2int complience. """
def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, outstanding_queries=None, log=None,
timeslack=0, debug=0):
timeslack=0, debug=0, asynchop=False):
StatusResponse.__init__(self, sec_context, return_addr, log,
timeslack, debug)
self.entity_id = entity_id
@@ -279,11 +298,12 @@ class AuthnResponse(StatusResponse):
self.ava = None
self.assertion = None
self.session_not_on_or_after = 0
def loads(self, xmldata, decode=True):
self._loads(xmldata, decode)
self.asynchop = asynchop
def loads(self, xmldata, decode=True, origxml=None):
self._loads(xmldata, decode, origxml)
if self.in_response_to in self.outstanding_queries:
if self.asynchop and self.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[self.in_response_to]
del self.outstanding_queries[self.in_response_to]
@@ -359,7 +379,11 @@ class AuthnResponse(StatusResponse):
return ava
def get_subject(self):
# The assertion must contain a Subject
""" The assertion must contain a Subject
:param asynch: If the connection is asynchronous there is
outstanding queries to connect to
"""
assert self.assertion.subject
subject = self.assertion.subject
subjconf = []
@@ -382,7 +406,7 @@ class AuthnResponse(StatusResponse):
if not time_util.later_than(data.not_on_or_after, data.not_before):
continue
if not self.came_from:
if self.asynchop and not self.came_from:
if data.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[
data.in_response_to]
@@ -420,7 +444,7 @@ class AuthnResponse(StatusResponse):
self.log.info("outstanding_queries: %s" % (
self.outstanding_queries,))
if self.context == "AuthNReq":
if self.context == "AuthnReq":
self.authn_statement_ok()
if not self.condition_ok():
@@ -436,7 +460,7 @@ class AuthnResponse(StatusResponse):
try:
self.get_subject()
if not self.came_from:
if self.asynchop and not self.came_from:
return False
else:
return True
@@ -533,9 +557,21 @@ class AuthnResponse(StatusResponse):
def __str__(self):
return "%s" % self.xmlstr
class AttributeResponse(AuthnResponse):
def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, log=None, timeslack=0, debug=0,
asynchop=True):
AuthnResponse.__init__(self, sec_context, return_addr, log, timeslack,
debug, asynchop)
self.entity_id = entity_id
self.attribute_converters = attribute_converters
self.assertion = None
self.context = "AttrQuery"
def response_factory(xmlstr, conf, return_addr=None,
outstanding_queries=None, log=None,
timeslack=0, debug=0, decode=True, request_id=0):
timeslack=0, debug=0, decode=True, request_id=0,
origxml=None):
sec_context = security_context(conf)
if not timeslack:
try:
@@ -549,7 +585,7 @@ def response_factory(xmlstr, conf, return_addr=None,
response = StatusResponse(sec_context, return_addr, log, timeslack,
debug, request_id)
try:
response.loads(xmlstr, decode)
response.loads(xmlstr, decode, origxml)
if response.response.assertion:
authnresp = AuthnResponse(sec_context, attribute_converters,
entity_id, return_addr, outstanding_queries, log,
@@ -558,7 +594,7 @@ def response_factory(xmlstr, conf, return_addr=None,
return authnresp
except TypeError:
response.signature_check = sec_context.correctly_signed_logout_response
response.loads(xmlstr, decode)
response.loads(xmlstr, decode, origxml)
logoutresp = LogoutResponse(sec_context, return_addr, log,
timeslack, debug)
logoutresp.update(response)