Pulled out all name_id related stuff into an own class

This commit is contained in:
Roland Hedberg
2010-03-26 15:16:29 +01:00
parent c4cbb60e41
commit b2e61b200d

View File

@@ -38,9 +38,14 @@ from saml2.config import Config
from saml2.cache import Cache from saml2.cache import Cache
from saml2.assertion import Assertion, Policy from saml2.assertion import Assertion, Policy
class IdentifierMap(object): class UnknownVO(Exception):
def __init__(self, dbname, debug=0, log=None): pass
class Identifier(object):
def __init__(self, dbname, entityid, voconf=None, debug=0, log=None):
self.map = shelve.open(dbname,writeback=True) self.map = shelve.open(dbname,writeback=True)
self.entityid = entityid
self.voconf = voconf
self.debug = debug self.debug = debug
self.log = log self.log = log
@@ -50,7 +55,7 @@ class IdentifierMap(object):
:param entity_id: SP entity ID or VO entity ID :param entity_id: SP entity ID or VO entity ID
:param subject_id: The local identifier of the subject :param subject_id: The local identifier of the subject
:return: A arbitrary identifier for the subject unique to the :return: An arbitrary identifier for the subject unique to the
entity_id entity_id
""" """
if self.debug: if self.debug:
@@ -77,6 +82,48 @@ class IdentifierMap(object):
return temp_id return temp_id
def _get_vo_identifier(self, sp_name_qualifier, userid, identity):
try:
vo_conf = self.voconf(sp_name_qualifier)
if "common_identifier" in vo_conf:
try:
subj_id = identity[vo_conf["common_identifier"]]
except KeyError:
raise MissingValue("Common identifier")
else:
return self.persistent_nameid(sp_name_qualifier, userid)
except KeyError:
raise UnknownVO("%s" % sp_name_qualifier)
try:
format = vo_conf["nameid_format"]
except KeyError:
format = saml.NAMEID_FORMAT_PERSISTENT
return args2dict(subj_id, format=format,
sp_name_qualifier=sp_name_qualifier)
def persistent_nameid(self, sp_name_qualifier, userid):
subj_id = self.persistent(sp_name_qualifier, userid)
return args2dict(subj_id, format=saml.NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier=sp_name_qualifier)
def construct_nameid(self, local_policy, userid, sp_entity_id,
identity=None, name_id_policy=None):
if name_id_policy and name_id_policy.sp_name_qualifier:
return self._get_vo_identifier(name_id_policy.sp_name_qualifier,
userid, identity)
else:
nameid_format = local_policy.get_nameid_format(sp_entity_id)
if nameid_format == saml.NAMEID_FORMAT_PERSISTENT:
return self.persistent_nameid(self.entityid, userid)
elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT:
return self.temporary_nameid()
def temporary_nameid(self):
return args2dict(sid(), format=saml.NAMEID_FORMAT_TRANSIENT)
class Server(object): class Server(object):
""" A class that does things that IdPs or AAs do """ """ A class that does things that IdPs or AAs do """
def __init__(self, config_file="", config=None, cache="", def __init__(self, config_file="", config=None, cache="",
@@ -99,8 +146,9 @@ class Server(object):
self.conf = Config() self.conf = Config()
self.conf.load_file(config_file) self.conf.load_file(config_file)
if "subject_data" in self.conf: if "subject_data" in self.conf:
self.id = IdentifierMap(self.conf["subject_data"], self.id = Identifier(self.conf["subject_data"],
self.debug, self.log) self.conf["entityid"], self.conf.vo_conf,
self.debug, self.log)
else: else:
self.id = None self.id = None
@@ -276,8 +324,12 @@ class Server(object):
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def do_aa_response(self, consumer_url, in_response_to, sp_entity_id, def do_aa_response(self, consumer_url, in_response_to, sp_entity_id,
identity=None, name_id=None, ip_address="", identity=None, userid="", name_id=None, ip_address="",
issuer=None, status=None, sign=False): issuer=None, status=None, sign=False,
name_id_policy=None):
name_id = self.id.construct_nameid(self.conf.aa_policy(), userid,
sp_entity_id, identity)
return self._response(consumer_url, in_response_to, return self._response(consumer_url, in_response_to,
sp_entity_id, identity, name_id, sp_entity_id, identity, name_id,
@@ -285,48 +337,44 @@ class Server(object):
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
def authn_response(self, identity, in_response_to, destination, spid,
name_id_policy, userid): # ------------------------------------------------------------------------
def authn_response(self, identity, in_response_to, destination,
sp_entity_id, name_id_policy, userid):
""" Constructs an AuthenticationResponse """ Constructs an AuthenticationResponse
:param identity: Information about an user :param identity: Information about an user
:param in_response_to: The identifier of the authentication request :param in_response_to: The identifier of the authentication request
this response is an answer to. this response is an answer to.
:param destination: Where the response should be sent :param destination: Where the response should be sent
:param sid: The entity identifier of the Service Provider :param sp_entity_id: The entity identifier of the Service Provider
:param name_id_policy: ... :param name_id_policy: ...
:param userid: The subject identifier :param userid: The subject identifier
:return: A XML string representing an authentication response :return: A XML string representing an authentication response
""" """
name_id = None try:
if name_id_policy.sp_name_qualifier: name_id = self.id.construct_nameid(self.conf.idp_policy(),
try: userid, sp_entity_id, identity,
vo_conf = self.conf.vo_conf(name_id_policy.sp_name_qualifier) name_id_policy)
subj_id = identity[vo_conf["common_identifier"]] except IOError, exc:
except KeyError: response = self.error_response(destination, in_response_to,
self.log.info( sp_entity_id, exc, name_id)
"Get persistent ID (%s,%s)" % ( return ("%s" % response).split("\n")
name_id_policy.sp_name_qualifier,userid))
subj_id = self.id.persistent(name_id_policy.sp_name_qualifier,
userid)
self.log.info("=> %s" % subj_id)
name_id = args2dict(subj_id,
format=saml.NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier=name_id_policy.sp_name_qualifier)
try: try:
resp = self.do_response( response = self.do_response(
destination, # consumer_url destination, # consumer_url
in_response_to, # in_response_to in_response_to, # in_response_to
spid, # sp_entity_id sp_entity_id, # sp_entity_id
identity, # identity as dictionary identity, # identity as dictionary
name_id, name_id,
userid
) )
except MissingValue, exc: except MissingValue, exc:
resp = self.error_response( destination, in_response_to, spid, response = self.error_response(destination, in_response_to,
exc, name_id) sp_entity_id, exc, name_id)
return ("%s" % resp).split("\n") return ("%s" % response).split("\n")