Pulled out all name_id related stuff into an own class

This commit is contained in:
Roland Hedberg
2010-03-26 15:16:29 +01:00
parent c4cbb60e41
commit b2e61b200d

View File

@@ -38,9 +38,14 @@ from saml2.config import Config
from saml2.cache import Cache
from saml2.assertion import Assertion, Policy
class IdentifierMap(object):
def __init__(self, dbname, debug=0, log=None):
class UnknownVO(Exception):
pass
class Identifier(object):
def __init__(self, dbname, entityid, voconf=None, debug=0, log=None):
self.map = shelve.open(dbname,writeback=True)
self.entityid = entityid
self.voconf = voconf
self.debug = debug
self.log = log
@@ -50,7 +55,7 @@ class IdentifierMap(object):
:param entity_id: SP entity ID or VO entity ID
:param subject_id: The local identifier of the subject
:return: A arbitrary identifier for the subject unique to the
:return: An arbitrary identifier for the subject unique to the
entity_id
"""
if self.debug:
@@ -76,6 +81,48 @@ class IdentifierMap(object):
self.map.sync()
return temp_id
def _get_vo_identifier(self, sp_name_qualifier, userid, identity):
try:
vo_conf = self.voconf(sp_name_qualifier)
if "common_identifier" in vo_conf:
try:
subj_id = identity[vo_conf["common_identifier"]]
except KeyError:
raise MissingValue("Common identifier")
else:
return self.persistent_nameid(sp_name_qualifier, userid)
except KeyError:
raise UnknownVO("%s" % sp_name_qualifier)
try:
format = vo_conf["nameid_format"]
except KeyError:
format = saml.NAMEID_FORMAT_PERSISTENT
return args2dict(subj_id, format=format,
sp_name_qualifier=sp_name_qualifier)
def persistent_nameid(self, sp_name_qualifier, userid):
subj_id = self.persistent(sp_name_qualifier, userid)
return args2dict(subj_id, format=saml.NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier=sp_name_qualifier)
def construct_nameid(self, local_policy, userid, sp_entity_id,
identity=None, name_id_policy=None):
if name_id_policy and name_id_policy.sp_name_qualifier:
return self._get_vo_identifier(name_id_policy.sp_name_qualifier,
userid, identity)
else:
nameid_format = local_policy.get_nameid_format(sp_entity_id)
if nameid_format == saml.NAMEID_FORMAT_PERSISTENT:
return self.persistent_nameid(self.entityid, userid)
elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT:
return self.temporary_nameid()
def temporary_nameid(self):
return args2dict(sid(), format=saml.NAMEID_FORMAT_TRANSIENT)
class Server(object):
""" A class that does things that IdPs or AAs do """
@@ -99,8 +146,9 @@ class Server(object):
self.conf = Config()
self.conf.load_file(config_file)
if "subject_data" in self.conf:
self.id = IdentifierMap(self.conf["subject_data"],
self.debug, self.log)
self.id = Identifier(self.conf["subject_data"],
self.conf["entityid"], self.conf.vo_conf,
self.debug, self.log)
else:
self.id = None
@@ -190,9 +238,9 @@ class Server(object):
def find_subject(self, subject, attribute=None):
pass
# ------------------------------------------------------------------------
# ------------------------------------------------------------------------
def _response(self, consumer_url, in_response_to, sp_entity_id,
identity=None, name_id=None, status=None, sign=False,
policy=Policy()):
@@ -276,57 +324,57 @@ class Server(object):
# ------------------------------------------------------------------------
def do_aa_response(self, consumer_url, in_response_to, sp_entity_id,
identity=None, name_id=None, ip_address="",
issuer=None, status=None, sign=False):
identity=None, userid="", name_id=None, ip_address="",
issuer=None, status=None, sign=False,
name_id_policy=None):
name_id = self.id.construct_nameid(self.conf.aa_policy(), userid,
sp_entity_id, identity)
return self._response(consumer_url, in_response_to,
sp_entity_id, identity, name_id,
status, sign, policy=self.conf.aa_policy())
# ------------------------------------------------------------------------
def authn_response(self, identity, in_response_to, destination, spid,
name_id_policy, userid):
# ------------------------------------------------------------------------
def authn_response(self, identity, in_response_to, destination,
sp_entity_id, name_id_policy, userid):
""" Constructs an AuthenticationResponse
:param identity: Information about an user
:param in_response_to: The identifier of the authentication request
this response is an answer to.
:param destination: Where the response should be sent
:param sid: The entity identifier of the Service Provider
:param sp_entity_id: The entity identifier of the Service Provider
:param name_id_policy: ...
:param userid: The subject identifier
:return: A XML string representing an authentication response
"""
name_id = None
if name_id_policy.sp_name_qualifier:
try:
vo_conf = self.conf.vo_conf(name_id_policy.sp_name_qualifier)
subj_id = identity[vo_conf["common_identifier"]]
except KeyError:
self.log.info(
"Get persistent ID (%s,%s)" % (
name_id_policy.sp_name_qualifier,userid))
subj_id = self.id.persistent(name_id_policy.sp_name_qualifier,
userid)
self.log.info("=> %s" % subj_id)
name_id = args2dict(subj_id,
format=saml.NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier=name_id_policy.sp_name_qualifier)
try:
name_id = self.id.construct_nameid(self.conf.idp_policy(),
userid, sp_entity_id, identity,
name_id_policy)
except IOError, exc:
response = self.error_response(destination, in_response_to,
sp_entity_id, exc, name_id)
return ("%s" % response).split("\n")
try:
resp = self.do_response(
response = self.do_response(
destination, # consumer_url
in_response_to, # in_response_to
spid, # sp_entity_id
sp_entity_id, # sp_entity_id
identity, # identity as dictionary
name_id,
userid
)
except MissingValue, exc:
resp = self.error_response( destination, in_response_to, spid,
exc, name_id)
response = self.error_response(destination, in_response_to,
sp_entity_id, exc, name_id)
return ("%s" % resp).split("\n")
return ("%s" % response).split("\n")