Closer to memcache. Have to store transient identifier at least as long as the session is active

This commit is contained in:
Roland Hedberg
2010-10-13 16:28:43 +02:00
parent f19b1700fe
commit 3cbe3074fc

View File

@@ -48,13 +48,22 @@ class UnknownVO(Exception):
class Identifier(object):
""" A class that handles identifiers of objects """
def __init__(self, dbname, entityid, voconf=None, debug=0, log=None):
def __init__(self, dbname, voconf=None, debug=0, log=None):
self.map = shelve.open(dbname, writeback=True)
self.entityid = entityid
self.voconf = voconf
self.debug = debug
self.log = log
def _store(self, typ, entity_id, local, remote):
self.map["|".join([typ, entity_id, "f", local])] = remote
self.map["|".join([typ, entity_id, "b", remote])] = local
def _get_remote(self, typ, entity_id, local):
return self.map["|".join([typ, entity_id, "f", local])]
def _get_local(self, typ, entity_id, remote):
return self.map["|".join([typ, entity_id, "b", remote])]
def persistent(self, entity_id, subject_id):
""" Keeps the link between a permanent identifier and a
temporary/pseudo-temporary identifier for a subject
@@ -68,26 +77,16 @@ class Identifier(object):
:return: An arbitrary identifier for the subject unique to the
service/group of services/VO with a given entity_id
"""
if self.debug:
self.log and self.log.debug("Id map keys: %s" % self.map.keys())
try:
emap = self.map[entity_id]
except KeyError:
emap = self.map[entity_id] = {"forward":{}, "backward":{}}
try:
if self.debug:
self.log.debug("map forward keys: %s" % emap["forward"].keys())
return emap["forward"][subject_id]
return self._get_remote("persistent", entity_id, subject_id)
except KeyError:
while True:
temp_id = sid()
if temp_id not in emap["backward"]:
try:
l = self._get_local("persistent", entity_id, temp_id)
except KeyError:
break
emap["forward"][subject_id] = temp_id
emap["backward"][temp_id] = subject_id
self.map[entity_id] = emap
self._store("persistent", entity_id, subject_id, temp_id)
self.map.sync()
return temp_id
@@ -113,8 +112,6 @@ class Identifier(object):
return saml.NameID(format=nameid_format,
sp_name_qualifier=sp_name_qualifier,
text=subj_id)
# return args2dict(subj_id, format=nameid_format,
# sp_name_qualifier=sp_name_qualifier)
def persistent_nameid(self, sp_name_qualifier, userid):
""" Get or create a persistent identifier for this object to be used
@@ -128,15 +125,20 @@ class Identifier(object):
return saml.NameID(format=saml.NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier=sp_name_qualifier,
text=subj_id)
# return args2dict(subj_id, format=saml.NAMEID_FORMAT_PERSISTENT,
# sp_name_qualifier=sp_name_qualifier)
def temporary_nameid(self):
def transient_nameid(self, sp_name_qualifier, userid):
""" Returns a random one-time identifier """
while True:
temp_id = sid()
try:
l = self._get_local("transient", sp_name_qualifier, temp_id)
except KeyError:
break
self._store("transient", sp_name_qualifier, userid, temp_id)
self.map.sync()
return saml.NameID(format=saml.NAMEID_FORMAT_TRANSIENT,
text=sid())
#return args2dict(sid(), format=saml.NAMEID_FORMAT_TRANSIENT)
text=temp_id)
def construct_nameid(self, local_policy, userid, sp_entity_id,
identity=None, name_id_policy=None):
@@ -157,11 +159,24 @@ class Identifier(object):
else:
nameid_format = local_policy.get_nameid_format(sp_entity_id)
if nameid_format == saml.NAMEID_FORMAT_PERSISTENT:
return self.persistent_nameid(self.entityid, userid)
return self.persistent_nameid(sp_entity_id, userid)
elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT:
return self.temporary_nameid()
return self.transient_nameid(sp_entity_id, userid)
def local_name(self, entity_id, remote_id):
""" Only works for persistent names
:param entity_id: The identifier of the entity that got the remote id
:param remote_id: The identifier that was exported
:return: Local identifier
"""
try:
return self._get_local("persistent", entity_id, remote_id)
except KeyError:
try:
return self._get_local("transient", entity_id, remote_id)
except KeyError:
return None
class Server(object):
""" A class that does things that IdPs or AAs do """
@@ -192,7 +207,7 @@ class Server(object):
self.conf.load_file(config_file)
if "subject_data" in self.conf:
self.ident = Identifier(self.conf["subject_data"],
self.conf["entityid"], self.conf.vo_conf,
self.conf.vo_conf,
self.debug, self.log)
else:
self.ident = None