From b2e61b200db891bdf7b9bf1f019120b8daf3b1c6 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Fri, 26 Mar 2010 15:16:29 +0100 Subject: [PATCH] Pulled out all name_id related stuff into an own class --- src/saml2/server.py | 114 +++++++++++++++++++++++++++++++------------- 1 file changed, 81 insertions(+), 33 deletions(-) diff --git a/src/saml2/server.py b/src/saml2/server.py index bac4fa9..d5e6d4d 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -38,9 +38,14 @@ from saml2.config import Config from saml2.cache import Cache from saml2.assertion import Assertion, Policy -class IdentifierMap(object): - def __init__(self, dbname, debug=0, log=None): +class UnknownVO(Exception): + pass + +class Identifier(object): + def __init__(self, dbname, entityid, voconf=None, debug=0, log=None): self.map = shelve.open(dbname,writeback=True) + self.entityid = entityid + self.voconf = voconf self.debug = debug self.log = log @@ -50,7 +55,7 @@ class IdentifierMap(object): :param entity_id: SP entity ID or VO entity ID :param subject_id: The local identifier of the subject - :return: A arbitrary identifier for the subject unique to the + :return: An arbitrary identifier for the subject unique to the entity_id """ if self.debug: @@ -76,6 +81,48 @@ class IdentifierMap(object): self.map.sync() return temp_id + + def _get_vo_identifier(self, sp_name_qualifier, userid, identity): + try: + vo_conf = self.voconf(sp_name_qualifier) + if "common_identifier" in vo_conf: + try: + subj_id = identity[vo_conf["common_identifier"]] + except KeyError: + raise MissingValue("Common identifier") + else: + return self.persistent_nameid(sp_name_qualifier, userid) + except KeyError: + raise UnknownVO("%s" % sp_name_qualifier) + + try: + format = vo_conf["nameid_format"] + except KeyError: + format = saml.NAMEID_FORMAT_PERSISTENT + + return args2dict(subj_id, format=format, + sp_name_qualifier=sp_name_qualifier) + + def persistent_nameid(self, sp_name_qualifier, userid): + subj_id = self.persistent(sp_name_qualifier, userid) + return args2dict(subj_id, format=saml.NAMEID_FORMAT_PERSISTENT, + sp_name_qualifier=sp_name_qualifier) + + def construct_nameid(self, local_policy, userid, sp_entity_id, + identity=None, name_id_policy=None): + if name_id_policy and name_id_policy.sp_name_qualifier: + return self._get_vo_identifier(name_id_policy.sp_name_qualifier, + userid, identity) + else: + nameid_format = local_policy.get_nameid_format(sp_entity_id) + if nameid_format == saml.NAMEID_FORMAT_PERSISTENT: + return self.persistent_nameid(self.entityid, userid) + elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT: + return self.temporary_nameid() + + def temporary_nameid(self): + return args2dict(sid(), format=saml.NAMEID_FORMAT_TRANSIENT) + class Server(object): """ A class that does things that IdPs or AAs do """ @@ -99,8 +146,9 @@ class Server(object): self.conf = Config() self.conf.load_file(config_file) if "subject_data" in self.conf: - self.id = IdentifierMap(self.conf["subject_data"], - self.debug, self.log) + self.id = Identifier(self.conf["subject_data"], + self.conf["entityid"], self.conf.vo_conf, + self.debug, self.log) else: self.id = None @@ -190,9 +238,9 @@ class Server(object): def find_subject(self, subject, attribute=None): pass - - # ------------------------------------------------------------------------ + # ------------------------------------------------------------------------ + def _response(self, consumer_url, in_response_to, sp_entity_id, identity=None, name_id=None, status=None, sign=False, policy=Policy()): @@ -276,57 +324,57 @@ class Server(object): # ------------------------------------------------------------------------ def do_aa_response(self, consumer_url, in_response_to, sp_entity_id, - identity=None, name_id=None, ip_address="", - issuer=None, status=None, sign=False): + identity=None, userid="", name_id=None, ip_address="", + issuer=None, status=None, sign=False, + name_id_policy=None): + name_id = self.id.construct_nameid(self.conf.aa_policy(), userid, + sp_entity_id, identity) + return self._response(consumer_url, in_response_to, sp_entity_id, identity, name_id, status, sign, policy=self.conf.aa_policy()) # ------------------------------------------------------------------------ - def authn_response(self, identity, in_response_to, destination, spid, - name_id_policy, userid): + + # ------------------------------------------------------------------------ + + def authn_response(self, identity, in_response_to, destination, + sp_entity_id, name_id_policy, userid): """ Constructs an AuthenticationResponse :param identity: Information about an user :param in_response_to: The identifier of the authentication request this response is an answer to. :param destination: Where the response should be sent - :param sid: The entity identifier of the Service Provider + :param sp_entity_id: The entity identifier of the Service Provider :param name_id_policy: ... :param userid: The subject identifier :return: A XML string representing an authentication response """ - name_id = None - if name_id_policy.sp_name_qualifier: - try: - vo_conf = self.conf.vo_conf(name_id_policy.sp_name_qualifier) - subj_id = identity[vo_conf["common_identifier"]] - except KeyError: - self.log.info( - "Get persistent ID (%s,%s)" % ( - name_id_policy.sp_name_qualifier,userid)) - subj_id = self.id.persistent(name_id_policy.sp_name_qualifier, - userid) - self.log.info("=> %s" % subj_id) - - name_id = args2dict(subj_id, - format=saml.NAMEID_FORMAT_PERSISTENT, - sp_name_qualifier=name_id_policy.sp_name_qualifier) + try: + name_id = self.id.construct_nameid(self.conf.idp_policy(), + userid, sp_entity_id, identity, + name_id_policy) + except IOError, exc: + response = self.error_response(destination, in_response_to, + sp_entity_id, exc, name_id) + return ("%s" % response).split("\n") try: - resp = self.do_response( + response = self.do_response( destination, # consumer_url in_response_to, # in_response_to - spid, # sp_entity_id + sp_entity_id, # sp_entity_id identity, # identity as dictionary name_id, + userid ) except MissingValue, exc: - resp = self.error_response( destination, in_response_to, spid, - exc, name_id) + response = self.error_response(destination, in_response_to, + sp_entity_id, exc, name_id) - return ("%s" % resp).split("\n") \ No newline at end of file + return ("%s" % response).split("\n") \ No newline at end of file