refactored plus new test

This commit is contained in:
Roland Hedberg
2009-11-03 20:09:17 +01:00
parent c965acac34
commit e7abf65f66

View File

@@ -31,11 +31,15 @@ class UnknownPricipal(Exception):
class UnsupportedBinding(Exception): class UnsupportedBinding(Exception):
pass pass
class OtherError(Exception):
pass
EXCEPTION2STATUS = { EXCEPTION2STATUS = {
VersionMismatch: samlp.STATUS_VERSION_MISMATCH, VersionMismatch: samlp.STATUS_VERSION_MISMATCH,
UnknownPricipal: samlp.STATUS_UNKNOWN_PRINCIPAL, UnknownPricipal: samlp.STATUS_UNKNOWN_PRINCIPAL,
UnsupportedBinding: samlp.STATUS_UNSUPPORTED_BINDING, UnsupportedBinding: samlp.STATUS_UNSUPPORTED_BINDING,
OtherError: samlp.STATUS_UNKNOWN_PRINCIPAL,
} }
def properties(klass): def properties(klass):
@@ -59,11 +63,11 @@ class Server(object):
#assert "service_url" in self.conf #assert "service_url" in self.conf
assert "entityid" in self.conf assert "entityid" in self.conf
if "my_key" not in self.conf: if "key_file" not in self.conf:
self.conf["my_key"] = None self.conf["key_file"] = None
else: else:
# If you have a key file you have to have a cert file # If you have a key file you have to have a cert file
assert "my_cert" in self.conf assert "cert_file" in self.conf
if "metadata" in self.conf: if "metadata" in self.conf:
md = MetaData() md = MetaData()
@@ -88,12 +92,12 @@ class Server(object):
def status_from_exception(self, exception): def status_from_exception(self, exception):
return { return {
"status_code": { "status_code": {
"value": STATUS_RESPONDER, "value": samlp.STATUS_RESPONDER,
"status_code": { "status_code": {
"value": EXCEPTION2STATUS( exception), "value": EXCEPTION2STATUS[exception.__class__],
}, },
"message": exception.args[0], },
} "status_message": exception.args[0],
} }
def status(self, status, message=None, status_code=None): def status(self, status, message=None, status_code=None):
@@ -225,6 +229,10 @@ class Server(object):
if not consumer_url: # what to do ? if not consumer_url: # what to do ?
raise UnsupportedBinding(spentityid) raise UnsupportedBinding(spentityid)
if consumer_url != return_destination:
# serious error on someones behalf
raise OtherError("ConsumerURL and return destination mismatch")
policy = request.name_id_policy policy = request.name_id_policy
if policy.allow_create.lower() == "true" and \ if policy.allow_create.lower() == "true" and \
policy.format == saml.NAMEID_FORMAT_TRANSIENT: policy.format == saml.NAMEID_FORMAT_TRANSIENT: