Need to handle extension elements and status responses my be synchronous.

This commit is contained in:
Roland Hedberg
2013-01-10 14:50:52 +01:00
parent c7a8d3d63e
commit fd8c6994b2

View File

@@ -19,7 +19,10 @@ import calendar
import base64
import logging
from saml2 import samlp
import xmldsig as ds
import xmlenc as xenc
from saml2 import samlp, extension_elements_to_elements
from saml2 import saml
from saml2 import extension_element_to_element
from saml2 import time_util
@@ -33,7 +36,7 @@ from saml2.sigver import security_context
from saml2.sigver import SignatureError
from saml2.sigver import signed
from saml2.attribute_converter import to_local
from saml2.time_util import str_to_time
from saml2.time_util import str_to_time, later_than
from saml2.validate import validate_on_or_after
from saml2.validate import validate_before
@@ -53,7 +56,7 @@ class IncorrectlySigned(Exception):
def _dummy(_):
return None
def for_me(condition, myself ):
def for_me(condition, myself):
# Am I among the intended audiences
for restriction in condition.audience_restriction:
for audience in restriction.audience:
@@ -94,7 +97,7 @@ def attribute_response(conf, return_addr, timeslack=0, asynchop=False,
class StatusResponse(object):
def __init__(self, sec_context, return_addr=None, timeslack=0,
request_id=0):
request_id=0, asynchop=True):
self.sec = sec_context
self.return_addr = return_addr
@@ -108,6 +111,7 @@ class StatusResponse(object):
self.in_response_to = None
self.signature_check = self.sec.correctly_signed_response
self.not_signed = False
self.asynchop = asynchop
def _clear(self):
self.xmlstr = ""
@@ -204,11 +208,12 @@ class StatusResponse(object):
return None
assert self.response.version == "2.0"
if self.response.destination and \
self.response.destination != self.return_addr:
logger.error("%s != %s" % (self.response.destination,
self.return_addr))
return None
if self.asynchop:
if self.response.destination and \
self.response.destination != self.return_addr:
logger.error("%s != %s" % (self.response.destination,
self.return_addr))
return None
assert self.issue_instant_ok()
assert self.status_ok()
@@ -233,10 +238,14 @@ class StatusResponse(object):
return self.response.issuer.text.strip()
class LogoutResponse(StatusResponse):
def __init__(self, sec_context, return_addr=None, timeslack=0):
StatusResponse.__init__(self, sec_context, return_addr, timeslack)
def __init__(self, sec_context, return_addr=None, timeslack=0,
asynchop=True):
StatusResponse.__init__(self, sec_context, return_addr, timeslack,
asynchop=asynchop)
self.signature_check = self.sec.correctly_signed_logout_response
class AuthnResponse(StatusResponse):
""" This is where all the profile compliance is checked.
This one does saml2int compliance. """
@@ -246,7 +255,8 @@ class AuthnResponse(StatusResponse):
timeslack=0, asynchop=True, allow_unsolicited=False,
test=False):
StatusResponse.__init__(self, sec_context, return_addr, timeslack)
StatusResponse.__init__(self, sec_context, return_addr, timeslack,
asynchop=asynchop)
self.entity_id = entity_id
self.attribute_converters = attribute_converters
if outstanding_queries:
@@ -258,7 +268,6 @@ class AuthnResponse(StatusResponse):
self.ava = None
self.assertion = None
self.session_not_on_or_after = 0
self.asynchop = asynchop
self.allow_unsolicited = allow_unsolicited
self.test = test
@@ -311,20 +320,34 @@ class AuthnResponse(StatusResponse):
lax = True
assert self.assertion.conditions
condition = self.assertion.conditions
logger.debug("condition: %s" % condition)
# if no sub-elements or elements are supplied, then the
# assertion is considered to be valid.
if not condition.keyswv():
return True
# if both are present NotBefore must be earlier than NotOnOrAfter
if condition.not_before and condition.not_on_or_after:
if not later_than(condition.not_on_or_after, condition.not_before):
return False
try:
self.not_on_or_after = validate_on_or_after(
if condition.not_on_or_after:
self.not_on_or_after = validate_on_or_after(
condition.not_on_or_after,
self.timeslack)
validate_before(condition.not_before, self.timeslack)
if condition.not_before:
validate_before(condition.not_before, self.timeslack)
except Exception, excp:
logger.error("Exception on condition: %s" % (excp,))
if not lax:
raise
else:
self.not_on_or_after = 0
if not for_me(condition, self.entity_id):
if not lax:
#print condition
@@ -376,12 +399,19 @@ class AuthnResponse(StatusResponse):
return ava
def _bearer_confirmed(self, data):
# These two will raise exception if untrue
if not data:
return False
if data.address:
if not valid_address(data.address):
return False
# These two will raise exception if untrue
validate_on_or_after(data.not_on_or_after, self.timeslack)
validate_before(data.not_before, self.timeslack)
# not_before must be < not_on_or_after
if not time_util.later_than(data.not_on_or_after, data.not_before):
if not later_than(data.not_on_or_after, data.not_before):
return False
if self.asynchop and not self.came_from:
@@ -403,10 +433,16 @@ class AuthnResponse(StatusResponse):
return True
def _holder_of_key_confirmed(self, data):
if not data.key_info:
if not data:
return False
else:
return True
has_keyinfo = False
for element in extension_elements_to_elements(data,
[samlp, saml, xenc, ds]):
if isinstance(element, ds.KeyInfo):
has_keyinfo = True
return has_keyinfo
def get_subject(self):
""" The assertion must contain a Subject
@@ -415,21 +451,13 @@ class AuthnResponse(StatusResponse):
subject = self.assertion.subject
subjconf = []
for subject_confirmation in subject.subject_confirmation:
data = subject_confirmation.subject_confirmation_data
if not data:
# I don't know where this belongs so I ignore it
continue
if data.address:
if not valid_address(data.address):
# ignore this subject_confirmation
continue
_data = subject_confirmation.subject_confirmation_data
if subject_confirmation.method == SCM_BEARER:
if not self._bearer_confirmed(data):
if not self._bearer_confirmed(_data):
continue
elif subject_confirmation.method == SCM_HOLDER_OF_KEY:
if not self._holder_of_key_confirmed(data):
if not self._holder_of_key_confirmed(_data):
continue
elif subject_confirmation.method == SCM_SENDER_VOUCHES:
pass
@@ -643,7 +671,8 @@ def response_factory(xmlstr, conf, return_addr=None,
attribute_converters = conf.attribute_converters
entity_id = conf.entityid
response = StatusResponse(sec_context, return_addr, timeslack, request_id)
response = StatusResponse(sec_context, return_addr, timeslack, request_id,
asynchop)
try:
response.loads(xmlstr, decode, origxml)
if response.response.assertion or response.response.encrypted_assertion:
@@ -655,7 +684,8 @@ def response_factory(xmlstr, conf, return_addr=None,
except TypeError:
response.signature_check = sec_context.correctly_signed_logout_response
response.loads(xmlstr, decode, origxml)
logoutresp = LogoutResponse(sec_context, return_addr, timeslack)
logoutresp = LogoutResponse(sec_context, return_addr, timeslack,
asynchop=asynchop)
logoutresp.update(response)
return logoutresp