From 59326927262fd33bfed69d827f4d2f0a507acf50 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Wed, 27 Apr 2011 14:15:51 +0200 Subject: [PATCH] Made AttributeResponse subclass of AuthResponse. Specified different processes depending on asynchronous or synchronous (SOAP) operation. Use the original XML document when checking signature --- src/saml2/response.py | 128 +++++++++++++++++++++++++++--------------- 1 file changed, 82 insertions(+), 46 deletions(-) diff --git a/src/saml2/response.py b/src/saml2/response.py index e2b152b..556a4f2 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/env python # -*- coding: utf-8 -*- # # Copyright (C) 2010 UmeƄ University @@ -80,7 +80,8 @@ def attribute_response(conf, return_addr, log=None, timeslack=0, timeslack = 0 return AttributeResponse(sec, conf.attribute_converters, conf.entityid, - return_addr, log, timeslack, debug) + return_addr, log, timeslack, debug, + asynchop=False) class StatusResponse(object): def __init__(self, sec_context, return_addr=None, log=None, timeslack=0, @@ -136,7 +137,7 @@ class StatusResponse(object): self.response = self.sec.check_signature(instance) return self._postamble() - def _loads(self, xmldata, decode=True): + def _loads(self, xmldata, decode=True, origxml=None): if decode: decoded_xml = base64.b64decode(xmldata) else: @@ -146,9 +147,12 @@ class StatusResponse(object): self.xmlstr = decoded_xml[:] if self.debug: self.log.info("xmlstr: %s" % (self.xmlstr,)) + fil = open("response.xml", "w") + fil.write(self.xmlstr) + fil.close() try: - self.response = self.signature_check(decoded_xml) + self.response = self.signature_check(decoded_xml, origdoc=origxml) except TypeError: raise except SignatureError: @@ -207,8 +211,8 @@ class StatusResponse(object): assert self.status_ok() return self - def loads(self, xmldata, decode=True): - return self._loads(xmldata, decode) + def loads(self, xmldata, decode=True, origxml=None): + return self._loads(xmldata, decode, origxml) def verify(self): try: @@ -231,41 +235,56 @@ class LogoutResponse(StatusResponse): debug) self.signature_check = self.sec.correctly_signed_logout_response -class AttributeResponse(StatusResponse): - def __init__(self, sec_context, attribute_converters, entity_id, - return_addr=None, log=None, timeslack=0, debug=0): - StatusResponse.__init__(self, sec_context, return_addr, log, timeslack, - debug) - self.entity_id = entity_id - self.attribute_converters = attribute_converters - self.assertion = None - - def get_identity(self): - # The assertion can contain zero or one attributeStatements - if not self.assertion.attribute_statement: - self.log.error("Missing Attribute Statement") - ava = {} - else: - assert len(self.assertion.attribute_statement) == 1 - - if self.debug: - self.log.info("Attribute Statement: %s" % ( - self.assertion.attribute_statement[0],)) - for aconv in self.attribute_converters: - self.log.info( - "Converts name format: %s" % (aconv.name_format,)) - - ava = to_local(self.attribute_converters, - self.assertion.attribute_statement[0]) - return ava - +#class AttributeResponse(StatusResponse): +# def __init__(self, sec_context, attribute_converters, entity_id, +# return_addr=None, log=None, timeslack=0, debug=0): +# StatusResponse.__init__(self, sec_context, return_addr, log, timeslack, +# debug) +# self.entity_id = entity_id +# self.attribute_converters = attribute_converters +# self.assertion = None +# +# def get_identity(self): +# # The assertion can contain zero or one attributeStatements +# if not self.assertion.attribute_statement: +# self.log.error("Missing Attribute Statement") +# ava = {} +# else: +# assert len(self.assertion.attribute_statement) == 1 +# +# if self.debug: +# self.log.info("Attribute Statement: %s" % ( +# self.assertion.attribute_statement[0],)) +# for aconv in self.attribute_converters: +# self.log.info( +# "Converts name format: %s" % (aconv.name_format,)) +# +# ava = to_local(self.attribute_converters, +# self.assertion.attribute_statement[0]) +# return ava +# +# def session_info(self): +# """ Returns a predefined set of information gleened from the +# response. +# :returns: Dictionary with information +# """ +# if self.session_not_on_or_after > 0: +# nooa = self.session_not_on_or_after +# else: +# nooa = self.not_on_or_after +# +# return { "ava": self.ava, "name_id": self.name_id, +# "came_from": self.came_from, "issuer": self.issuer(), +# "not_on_or_after": nooa, +# "authn_info": self.authn_info() } + class AuthnResponse(StatusResponse): """ This is where all the profile complience is checked. This one does saml2int complience. """ def __init__(self, sec_context, attribute_converters, entity_id, return_addr=None, outstanding_queries=None, log=None, - timeslack=0, debug=0): + timeslack=0, debug=0, asynchop=False): StatusResponse.__init__(self, sec_context, return_addr, log, timeslack, debug) self.entity_id = entity_id @@ -279,11 +298,12 @@ class AuthnResponse(StatusResponse): self.ava = None self.assertion = None self.session_not_on_or_after = 0 - - def loads(self, xmldata, decode=True): - self._loads(xmldata, decode) + self.asynchop = asynchop + + def loads(self, xmldata, decode=True, origxml=None): + self._loads(xmldata, decode, origxml) - if self.in_response_to in self.outstanding_queries: + if self.asynchop and self.in_response_to in self.outstanding_queries: self.came_from = self.outstanding_queries[self.in_response_to] del self.outstanding_queries[self.in_response_to] @@ -359,7 +379,11 @@ class AuthnResponse(StatusResponse): return ava def get_subject(self): - # The assertion must contain a Subject + """ The assertion must contain a Subject + + :param asynch: If the connection is asynchronous there is + outstanding queries to connect to + """ assert self.assertion.subject subject = self.assertion.subject subjconf = [] @@ -382,7 +406,7 @@ class AuthnResponse(StatusResponse): if not time_util.later_than(data.not_on_or_after, data.not_before): continue - if not self.came_from: + if self.asynchop and not self.came_from: if data.in_response_to in self.outstanding_queries: self.came_from = self.outstanding_queries[ data.in_response_to] @@ -420,7 +444,7 @@ class AuthnResponse(StatusResponse): self.log.info("outstanding_queries: %s" % ( self.outstanding_queries,)) - if self.context == "AuthNReq": + if self.context == "AuthnReq": self.authn_statement_ok() if not self.condition_ok(): @@ -436,7 +460,7 @@ class AuthnResponse(StatusResponse): try: self.get_subject() - if not self.came_from: + if self.asynchop and not self.came_from: return False else: return True @@ -533,9 +557,21 @@ class AuthnResponse(StatusResponse): def __str__(self): return "%s" % self.xmlstr +class AttributeResponse(AuthnResponse): + def __init__(self, sec_context, attribute_converters, entity_id, + return_addr=None, log=None, timeslack=0, debug=0, + asynchop=True): + AuthnResponse.__init__(self, sec_context, return_addr, log, timeslack, + debug, asynchop) + self.entity_id = entity_id + self.attribute_converters = attribute_converters + self.assertion = None + self.context = "AttrQuery" + def response_factory(xmlstr, conf, return_addr=None, outstanding_queries=None, log=None, - timeslack=0, debug=0, decode=True, request_id=0): + timeslack=0, debug=0, decode=True, request_id=0, + origxml=None): sec_context = security_context(conf) if not timeslack: try: @@ -549,7 +585,7 @@ def response_factory(xmlstr, conf, return_addr=None, response = StatusResponse(sec_context, return_addr, log, timeslack, debug, request_id) try: - response.loads(xmlstr, decode) + response.loads(xmlstr, decode, origxml) if response.response.assertion: authnresp = AuthnResponse(sec_context, attribute_converters, entity_id, return_addr, outstanding_queries, log, @@ -558,7 +594,7 @@ def response_factory(xmlstr, conf, return_addr=None, return authnresp except TypeError: response.signature_check = sec_context.correctly_signed_logout_response - response.loads(xmlstr, decode) + response.loads(xmlstr, decode, origxml) logoutresp = LogoutResponse(sec_context, return_addr, log, timeslack, debug) logoutresp.update(response)