Made AttributeResponse subclass of AuthResponse. Specified different processes depending on asynchronous or synchronous (SOAP) operation. Use the original XML document when checking signature

This commit is contained in:
Roland Hedberg
2011-04-27 14:15:51 +02:00
parent 937cabd050
commit 5932692726

View File

@@ -1,4 +1,4 @@
#!/usr/bin/python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (C) 2010 Umeå University # Copyright (C) 2010 Umeå University
@@ -80,7 +80,8 @@ def attribute_response(conf, return_addr, log=None, timeslack=0,
timeslack = 0 timeslack = 0
return AttributeResponse(sec, conf.attribute_converters, conf.entityid, return AttributeResponse(sec, conf.attribute_converters, conf.entityid,
return_addr, log, timeslack, debug) return_addr, log, timeslack, debug,
asynchop=False)
class StatusResponse(object): class StatusResponse(object):
def __init__(self, sec_context, return_addr=None, log=None, timeslack=0, def __init__(self, sec_context, return_addr=None, log=None, timeslack=0,
@@ -136,7 +137,7 @@ class StatusResponse(object):
self.response = self.sec.check_signature(instance) self.response = self.sec.check_signature(instance)
return self._postamble() return self._postamble()
def _loads(self, xmldata, decode=True): def _loads(self, xmldata, decode=True, origxml=None):
if decode: if decode:
decoded_xml = base64.b64decode(xmldata) decoded_xml = base64.b64decode(xmldata)
else: else:
@@ -146,9 +147,12 @@ class StatusResponse(object):
self.xmlstr = decoded_xml[:] self.xmlstr = decoded_xml[:]
if self.debug: if self.debug:
self.log.info("xmlstr: %s" % (self.xmlstr,)) self.log.info("xmlstr: %s" % (self.xmlstr,))
fil = open("response.xml", "w")
fil.write(self.xmlstr)
fil.close()
try: try:
self.response = self.signature_check(decoded_xml) self.response = self.signature_check(decoded_xml, origdoc=origxml)
except TypeError: except TypeError:
raise raise
except SignatureError: except SignatureError:
@@ -207,8 +211,8 @@ class StatusResponse(object):
assert self.status_ok() assert self.status_ok()
return self return self
def loads(self, xmldata, decode=True): def loads(self, xmldata, decode=True, origxml=None):
return self._loads(xmldata, decode) return self._loads(xmldata, decode, origxml)
def verify(self): def verify(self):
try: try:
@@ -231,33 +235,48 @@ class LogoutResponse(StatusResponse):
debug) debug)
self.signature_check = self.sec.correctly_signed_logout_response self.signature_check = self.sec.correctly_signed_logout_response
class AttributeResponse(StatusResponse): #class AttributeResponse(StatusResponse):
def __init__(self, sec_context, attribute_converters, entity_id, # def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, log=None, timeslack=0, debug=0): # return_addr=None, log=None, timeslack=0, debug=0):
StatusResponse.__init__(self, sec_context, return_addr, log, timeslack, # StatusResponse.__init__(self, sec_context, return_addr, log, timeslack,
debug) # debug)
self.entity_id = entity_id # self.entity_id = entity_id
self.attribute_converters = attribute_converters # self.attribute_converters = attribute_converters
self.assertion = None # self.assertion = None
#
def get_identity(self): # def get_identity(self):
# The assertion can contain zero or one attributeStatements # # The assertion can contain zero or one attributeStatements
if not self.assertion.attribute_statement: # if not self.assertion.attribute_statement:
self.log.error("Missing Attribute Statement") # self.log.error("Missing Attribute Statement")
ava = {} # ava = {}
else: # else:
assert len(self.assertion.attribute_statement) == 1 # assert len(self.assertion.attribute_statement) == 1
#
if self.debug: # if self.debug:
self.log.info("Attribute Statement: %s" % ( # self.log.info("Attribute Statement: %s" % (
self.assertion.attribute_statement[0],)) # self.assertion.attribute_statement[0],))
for aconv in self.attribute_converters: # for aconv in self.attribute_converters:
self.log.info( # self.log.info(
"Converts name format: %s" % (aconv.name_format,)) # "Converts name format: %s" % (aconv.name_format,))
#
ava = to_local(self.attribute_converters, # ava = to_local(self.attribute_converters,
self.assertion.attribute_statement[0]) # self.assertion.attribute_statement[0])
return ava # return ava
#
# def session_info(self):
# """ Returns a predefined set of information gleened from the
# response.
# :returns: Dictionary with information
# """
# if self.session_not_on_or_after > 0:
# nooa = self.session_not_on_or_after
# else:
# nooa = self.not_on_or_after
#
# return { "ava": self.ava, "name_id": self.name_id,
# "came_from": self.came_from, "issuer": self.issuer(),
# "not_on_or_after": nooa,
# "authn_info": self.authn_info() }
class AuthnResponse(StatusResponse): class AuthnResponse(StatusResponse):
""" This is where all the profile complience is checked. """ This is where all the profile complience is checked.
@@ -265,7 +284,7 @@ class AuthnResponse(StatusResponse):
def __init__(self, sec_context, attribute_converters, entity_id, def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, outstanding_queries=None, log=None, return_addr=None, outstanding_queries=None, log=None,
timeslack=0, debug=0): timeslack=0, debug=0, asynchop=False):
StatusResponse.__init__(self, sec_context, return_addr, log, StatusResponse.__init__(self, sec_context, return_addr, log,
timeslack, debug) timeslack, debug)
self.entity_id = entity_id self.entity_id = entity_id
@@ -279,11 +298,12 @@ class AuthnResponse(StatusResponse):
self.ava = None self.ava = None
self.assertion = None self.assertion = None
self.session_not_on_or_after = 0 self.session_not_on_or_after = 0
self.asynchop = asynchop
def loads(self, xmldata, decode=True): def loads(self, xmldata, decode=True, origxml=None):
self._loads(xmldata, decode) self._loads(xmldata, decode, origxml)
if self.in_response_to in self.outstanding_queries: if self.asynchop and self.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[self.in_response_to] self.came_from = self.outstanding_queries[self.in_response_to]
del self.outstanding_queries[self.in_response_to] del self.outstanding_queries[self.in_response_to]
@@ -359,7 +379,11 @@ class AuthnResponse(StatusResponse):
return ava return ava
def get_subject(self): def get_subject(self):
# The assertion must contain a Subject """ The assertion must contain a Subject
:param asynch: If the connection is asynchronous there is
outstanding queries to connect to
"""
assert self.assertion.subject assert self.assertion.subject
subject = self.assertion.subject subject = self.assertion.subject
subjconf = [] subjconf = []
@@ -382,7 +406,7 @@ class AuthnResponse(StatusResponse):
if not time_util.later_than(data.not_on_or_after, data.not_before): if not time_util.later_than(data.not_on_or_after, data.not_before):
continue continue
if not self.came_from: if self.asynchop and not self.came_from:
if data.in_response_to in self.outstanding_queries: if data.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[ self.came_from = self.outstanding_queries[
data.in_response_to] data.in_response_to]
@@ -420,7 +444,7 @@ class AuthnResponse(StatusResponse):
self.log.info("outstanding_queries: %s" % ( self.log.info("outstanding_queries: %s" % (
self.outstanding_queries,)) self.outstanding_queries,))
if self.context == "AuthNReq": if self.context == "AuthnReq":
self.authn_statement_ok() self.authn_statement_ok()
if not self.condition_ok(): if not self.condition_ok():
@@ -436,7 +460,7 @@ class AuthnResponse(StatusResponse):
try: try:
self.get_subject() self.get_subject()
if not self.came_from: if self.asynchop and not self.came_from:
return False return False
else: else:
return True return True
@@ -533,9 +557,21 @@ class AuthnResponse(StatusResponse):
def __str__(self): def __str__(self):
return "%s" % self.xmlstr return "%s" % self.xmlstr
class AttributeResponse(AuthnResponse):
def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, log=None, timeslack=0, debug=0,
asynchop=True):
AuthnResponse.__init__(self, sec_context, return_addr, log, timeslack,
debug, asynchop)
self.entity_id = entity_id
self.attribute_converters = attribute_converters
self.assertion = None
self.context = "AttrQuery"
def response_factory(xmlstr, conf, return_addr=None, def response_factory(xmlstr, conf, return_addr=None,
outstanding_queries=None, log=None, outstanding_queries=None, log=None,
timeslack=0, debug=0, decode=True, request_id=0): timeslack=0, debug=0, decode=True, request_id=0,
origxml=None):
sec_context = security_context(conf) sec_context = security_context(conf)
if not timeslack: if not timeslack:
try: try:
@@ -549,7 +585,7 @@ def response_factory(xmlstr, conf, return_addr=None,
response = StatusResponse(sec_context, return_addr, log, timeslack, response = StatusResponse(sec_context, return_addr, log, timeslack,
debug, request_id) debug, request_id)
try: try:
response.loads(xmlstr, decode) response.loads(xmlstr, decode, origxml)
if response.response.assertion: if response.response.assertion:
authnresp = AuthnResponse(sec_context, attribute_converters, authnresp = AuthnResponse(sec_context, attribute_converters,
entity_id, return_addr, outstanding_queries, log, entity_id, return_addr, outstanding_queries, log,
@@ -558,7 +594,7 @@ def response_factory(xmlstr, conf, return_addr=None,
return authnresp return authnresp
except TypeError: except TypeError:
response.signature_check = sec_context.correctly_signed_logout_response response.signature_check = sec_context.correctly_signed_logout_response
response.loads(xmlstr, decode) response.loads(xmlstr, decode, origxml)
logoutresp = LogoutResponse(sec_context, return_addr, log, logoutresp = LogoutResponse(sec_context, return_addr, log,
timeslack, debug) timeslack, debug)
logoutresp.update(response) logoutresp.update(response)