diff --git a/src/saml2/server.py b/src/saml2/server.py index 94b1136..f8623d9 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -48,13 +48,22 @@ class UnknownVO(Exception): class Identifier(object): """ A class that handles identifiers of objects """ - def __init__(self, dbname, entityid, voconf=None, debug=0, log=None): + def __init__(self, dbname, voconf=None, debug=0, log=None): self.map = shelve.open(dbname, writeback=True) - self.entityid = entityid self.voconf = voconf self.debug = debug self.log = log + def _store(self, typ, entity_id, local, remote): + self.map["|".join([typ, entity_id, "f", local])] = remote + self.map["|".join([typ, entity_id, "b", remote])] = local + + def _get_remote(self, typ, entity_id, local): + return self.map["|".join([typ, entity_id, "f", local])] + + def _get_local(self, typ, entity_id, remote): + return self.map["|".join([typ, entity_id, "b", remote])] + def persistent(self, entity_id, subject_id): """ Keeps the link between a permanent identifier and a temporary/pseudo-temporary identifier for a subject @@ -68,26 +77,16 @@ class Identifier(object): :return: An arbitrary identifier for the subject unique to the service/group of services/VO with a given entity_id """ - if self.debug: - self.log and self.log.debug("Id map keys: %s" % self.map.keys()) - try: - emap = self.map[entity_id] - except KeyError: - emap = self.map[entity_id] = {"forward":{}, "backward":{}} - - try: - if self.debug: - self.log.debug("map forward keys: %s" % emap["forward"].keys()) - return emap["forward"][subject_id] + return self._get_remote("persistent", entity_id, subject_id) except KeyError: while True: temp_id = sid() - if temp_id not in emap["backward"]: + try: + l = self._get_local("persistent", entity_id, temp_id) + except KeyError: break - emap["forward"][subject_id] = temp_id - emap["backward"][temp_id] = subject_id - self.map[entity_id] = emap + self._store("persistent", entity_id, subject_id, temp_id) self.map.sync() return temp_id @@ -113,8 +112,6 @@ class Identifier(object): return saml.NameID(format=nameid_format, sp_name_qualifier=sp_name_qualifier, text=subj_id) - # return args2dict(subj_id, format=nameid_format, - # sp_name_qualifier=sp_name_qualifier) def persistent_nameid(self, sp_name_qualifier, userid): """ Get or create a persistent identifier for this object to be used @@ -128,15 +125,20 @@ class Identifier(object): return saml.NameID(format=saml.NAMEID_FORMAT_PERSISTENT, sp_name_qualifier=sp_name_qualifier, text=subj_id) - - # return args2dict(subj_id, format=saml.NAMEID_FORMAT_PERSISTENT, - # sp_name_qualifier=sp_name_qualifier) - def temporary_nameid(self): + def transient_nameid(self, sp_name_qualifier, userid): """ Returns a random one-time identifier """ + while True: + temp_id = sid() + try: + l = self._get_local("transient", sp_name_qualifier, temp_id) + except KeyError: + break + self._store("transient", sp_name_qualifier, userid, temp_id) + self.map.sync() + return saml.NameID(format=saml.NAMEID_FORMAT_TRANSIENT, - text=sid()) - #return args2dict(sid(), format=saml.NAMEID_FORMAT_TRANSIENT) + text=temp_id) def construct_nameid(self, local_policy, userid, sp_entity_id, identity=None, name_id_policy=None): @@ -157,11 +159,24 @@ class Identifier(object): else: nameid_format = local_policy.get_nameid_format(sp_entity_id) if nameid_format == saml.NAMEID_FORMAT_PERSISTENT: - return self.persistent_nameid(self.entityid, userid) + return self.persistent_nameid(sp_entity_id, userid) elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT: - return self.temporary_nameid() + return self.transient_nameid(sp_entity_id, userid) + def local_name(self, entity_id, remote_id): + """ Only works for persistent names + :param entity_id: The identifier of the entity that got the remote id + :param remote_id: The identifier that was exported + :return: Local identifier + """ + try: + return self._get_local("persistent", entity_id, remote_id) + except KeyError: + try: + return self._get_local("transient", entity_id, remote_id) + except KeyError: + return None class Server(object): """ A class that does things that IdPs or AAs do """ @@ -192,7 +207,7 @@ class Server(object): self.conf.load_file(config_file) if "subject_data" in self.conf: self.ident = Identifier(self.conf["subject_data"], - self.conf["entityid"], self.conf.vo_conf, + self.conf.vo_conf, self.debug, self.log) else: self.ident = None