Made AttributeResponse subclass of AuthResponse. Specified different processes depending on asynchronous or synchronous (SOAP) operation. Use the original XML document when checking signature

This commit is contained in:
Roland Hedberg
2011-04-27 14:15:51 +02:00
parent 937cabd050
commit 5932692726

View File

@@ -1,4 +1,4 @@
#!/usr/bin/python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
# Copyright (C) 2010 Umeå University # Copyright (C) 2010 Umeå University
@@ -80,7 +80,8 @@ def attribute_response(conf, return_addr, log=None, timeslack=0,
timeslack = 0 timeslack = 0
return AttributeResponse(sec, conf.attribute_converters, conf.entityid, return AttributeResponse(sec, conf.attribute_converters, conf.entityid,
return_addr, log, timeslack, debug) return_addr, log, timeslack, debug,
asynchop=False)
class StatusResponse(object): class StatusResponse(object):
def __init__(self, sec_context, return_addr=None, log=None, timeslack=0, def __init__(self, sec_context, return_addr=None, log=None, timeslack=0,
@@ -136,7 +137,7 @@ class StatusResponse(object):
self.response = self.sec.check_signature(instance) self.response = self.sec.check_signature(instance)
return self._postamble() return self._postamble()
def _loads(self, xmldata, decode=True): def _loads(self, xmldata, decode=True, origxml=None):
if decode: if decode:
decoded_xml = base64.b64decode(xmldata) decoded_xml = base64.b64decode(xmldata)
else: else:
@@ -146,9 +147,12 @@ class StatusResponse(object):
self.xmlstr = decoded_xml[:] self.xmlstr = decoded_xml[:]
if self.debug: if self.debug:
self.log.info("xmlstr: %s" % (self.xmlstr,)) self.log.info("xmlstr: %s" % (self.xmlstr,))
fil = open("response.xml", "w")
fil.write(self.xmlstr)
fil.close()
try: try:
self.response = self.signature_check(decoded_xml) self.response = self.signature_check(decoded_xml, origdoc=origxml)
except TypeError: except TypeError:
raise raise
except SignatureError: except SignatureError:
@@ -207,8 +211,8 @@ class StatusResponse(object):
assert self.status_ok() assert self.status_ok()
return self return self
def loads(self, xmldata, decode=True): def loads(self, xmldata, decode=True, origxml=None):
return self._loads(xmldata, decode) return self._loads(xmldata, decode, origxml)
def verify(self): def verify(self):
try: try:
@@ -231,41 +235,56 @@ class LogoutResponse(StatusResponse):
debug) debug)
self.signature_check = self.sec.correctly_signed_logout_response self.signature_check = self.sec.correctly_signed_logout_response
class AttributeResponse(StatusResponse): #class AttributeResponse(StatusResponse):
def __init__(self, sec_context, attribute_converters, entity_id, # def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, log=None, timeslack=0, debug=0): # return_addr=None, log=None, timeslack=0, debug=0):
StatusResponse.__init__(self, sec_context, return_addr, log, timeslack, # StatusResponse.__init__(self, sec_context, return_addr, log, timeslack,
debug) # debug)
self.entity_id = entity_id # self.entity_id = entity_id
self.attribute_converters = attribute_converters # self.attribute_converters = attribute_converters
self.assertion = None # self.assertion = None
#
def get_identity(self): # def get_identity(self):
# The assertion can contain zero or one attributeStatements # # The assertion can contain zero or one attributeStatements
if not self.assertion.attribute_statement: # if not self.assertion.attribute_statement:
self.log.error("Missing Attribute Statement") # self.log.error("Missing Attribute Statement")
ava = {} # ava = {}
else: # else:
assert len(self.assertion.attribute_statement) == 1 # assert len(self.assertion.attribute_statement) == 1
#
if self.debug: # if self.debug:
self.log.info("Attribute Statement: %s" % ( # self.log.info("Attribute Statement: %s" % (
self.assertion.attribute_statement[0],)) # self.assertion.attribute_statement[0],))
for aconv in self.attribute_converters: # for aconv in self.attribute_converters:
self.log.info( # self.log.info(
"Converts name format: %s" % (aconv.name_format,)) # "Converts name format: %s" % (aconv.name_format,))
#
ava = to_local(self.attribute_converters, # ava = to_local(self.attribute_converters,
self.assertion.attribute_statement[0]) # self.assertion.attribute_statement[0])
return ava # return ava
#
# def session_info(self):
# """ Returns a predefined set of information gleened from the
# response.
# :returns: Dictionary with information
# """
# if self.session_not_on_or_after > 0:
# nooa = self.session_not_on_or_after
# else:
# nooa = self.not_on_or_after
#
# return { "ava": self.ava, "name_id": self.name_id,
# "came_from": self.came_from, "issuer": self.issuer(),
# "not_on_or_after": nooa,
# "authn_info": self.authn_info() }
class AuthnResponse(StatusResponse): class AuthnResponse(StatusResponse):
""" This is where all the profile complience is checked. """ This is where all the profile complience is checked.
This one does saml2int complience. """ This one does saml2int complience. """
def __init__(self, sec_context, attribute_converters, entity_id, def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, outstanding_queries=None, log=None, return_addr=None, outstanding_queries=None, log=None,
timeslack=0, debug=0): timeslack=0, debug=0, asynchop=False):
StatusResponse.__init__(self, sec_context, return_addr, log, StatusResponse.__init__(self, sec_context, return_addr, log,
timeslack, debug) timeslack, debug)
self.entity_id = entity_id self.entity_id = entity_id
@@ -279,11 +298,12 @@ class AuthnResponse(StatusResponse):
self.ava = None self.ava = None
self.assertion = None self.assertion = None
self.session_not_on_or_after = 0 self.session_not_on_or_after = 0
self.asynchop = asynchop
def loads(self, xmldata, decode=True):
self._loads(xmldata, decode) def loads(self, xmldata, decode=True, origxml=None):
self._loads(xmldata, decode, origxml)
if self.in_response_to in self.outstanding_queries: if self.asynchop and self.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[self.in_response_to] self.came_from = self.outstanding_queries[self.in_response_to]
del self.outstanding_queries[self.in_response_to] del self.outstanding_queries[self.in_response_to]
@@ -359,7 +379,11 @@ class AuthnResponse(StatusResponse):
return ava return ava
def get_subject(self): def get_subject(self):
# The assertion must contain a Subject """ The assertion must contain a Subject
:param asynch: If the connection is asynchronous there is
outstanding queries to connect to
"""
assert self.assertion.subject assert self.assertion.subject
subject = self.assertion.subject subject = self.assertion.subject
subjconf = [] subjconf = []
@@ -382,7 +406,7 @@ class AuthnResponse(StatusResponse):
if not time_util.later_than(data.not_on_or_after, data.not_before): if not time_util.later_than(data.not_on_or_after, data.not_before):
continue continue
if not self.came_from: if self.asynchop and not self.came_from:
if data.in_response_to in self.outstanding_queries: if data.in_response_to in self.outstanding_queries:
self.came_from = self.outstanding_queries[ self.came_from = self.outstanding_queries[
data.in_response_to] data.in_response_to]
@@ -420,7 +444,7 @@ class AuthnResponse(StatusResponse):
self.log.info("outstanding_queries: %s" % ( self.log.info("outstanding_queries: %s" % (
self.outstanding_queries,)) self.outstanding_queries,))
if self.context == "AuthNReq": if self.context == "AuthnReq":
self.authn_statement_ok() self.authn_statement_ok()
if not self.condition_ok(): if not self.condition_ok():
@@ -436,7 +460,7 @@ class AuthnResponse(StatusResponse):
try: try:
self.get_subject() self.get_subject()
if not self.came_from: if self.asynchop and not self.came_from:
return False return False
else: else:
return True return True
@@ -533,9 +557,21 @@ class AuthnResponse(StatusResponse):
def __str__(self): def __str__(self):
return "%s" % self.xmlstr return "%s" % self.xmlstr
class AttributeResponse(AuthnResponse):
def __init__(self, sec_context, attribute_converters, entity_id,
return_addr=None, log=None, timeslack=0, debug=0,
asynchop=True):
AuthnResponse.__init__(self, sec_context, return_addr, log, timeslack,
debug, asynchop)
self.entity_id = entity_id
self.attribute_converters = attribute_converters
self.assertion = None
self.context = "AttrQuery"
def response_factory(xmlstr, conf, return_addr=None, def response_factory(xmlstr, conf, return_addr=None,
outstanding_queries=None, log=None, outstanding_queries=None, log=None,
timeslack=0, debug=0, decode=True, request_id=0): timeslack=0, debug=0, decode=True, request_id=0,
origxml=None):
sec_context = security_context(conf) sec_context = security_context(conf)
if not timeslack: if not timeslack:
try: try:
@@ -549,7 +585,7 @@ def response_factory(xmlstr, conf, return_addr=None,
response = StatusResponse(sec_context, return_addr, log, timeslack, response = StatusResponse(sec_context, return_addr, log, timeslack,
debug, request_id) debug, request_id)
try: try:
response.loads(xmlstr, decode) response.loads(xmlstr, decode, origxml)
if response.response.assertion: if response.response.assertion:
authnresp = AuthnResponse(sec_context, attribute_converters, authnresp = AuthnResponse(sec_context, attribute_converters,
entity_id, return_addr, outstanding_queries, log, entity_id, return_addr, outstanding_queries, log,
@@ -558,7 +594,7 @@ def response_factory(xmlstr, conf, return_addr=None,
return authnresp return authnresp
except TypeError: except TypeError:
response.signature_check = sec_context.correctly_signed_logout_response response.signature_check = sec_context.correctly_signed_logout_response
response.loads(xmlstr, decode) response.loads(xmlstr, decode, origxml)
logoutresp = LogoutResponse(sec_context, return_addr, log, logoutresp = LogoutResponse(sec_context, return_addr, log,
timeslack, debug) timeslack, debug)
logoutresp.update(response) logoutresp.update(response)