Pulled out all name_id related stuff into an own class
This commit is contained in:
@@ -38,9 +38,14 @@ from saml2.config import Config
|
||||
from saml2.cache import Cache
|
||||
from saml2.assertion import Assertion, Policy
|
||||
|
||||
class IdentifierMap(object):
|
||||
def __init__(self, dbname, debug=0, log=None):
|
||||
class UnknownVO(Exception):
|
||||
pass
|
||||
|
||||
class Identifier(object):
|
||||
def __init__(self, dbname, entityid, voconf=None, debug=0, log=None):
|
||||
self.map = shelve.open(dbname,writeback=True)
|
||||
self.entityid = entityid
|
||||
self.voconf = voconf
|
||||
self.debug = debug
|
||||
self.log = log
|
||||
|
||||
@@ -50,7 +55,7 @@ class IdentifierMap(object):
|
||||
|
||||
:param entity_id: SP entity ID or VO entity ID
|
||||
:param subject_id: The local identifier of the subject
|
||||
:return: A arbitrary identifier for the subject unique to the
|
||||
:return: An arbitrary identifier for the subject unique to the
|
||||
entity_id
|
||||
"""
|
||||
if self.debug:
|
||||
@@ -76,6 +81,48 @@ class IdentifierMap(object):
|
||||
self.map.sync()
|
||||
|
||||
return temp_id
|
||||
|
||||
def _get_vo_identifier(self, sp_name_qualifier, userid, identity):
|
||||
try:
|
||||
vo_conf = self.voconf(sp_name_qualifier)
|
||||
if "common_identifier" in vo_conf:
|
||||
try:
|
||||
subj_id = identity[vo_conf["common_identifier"]]
|
||||
except KeyError:
|
||||
raise MissingValue("Common identifier")
|
||||
else:
|
||||
return self.persistent_nameid(sp_name_qualifier, userid)
|
||||
except KeyError:
|
||||
raise UnknownVO("%s" % sp_name_qualifier)
|
||||
|
||||
try:
|
||||
format = vo_conf["nameid_format"]
|
||||
except KeyError:
|
||||
format = saml.NAMEID_FORMAT_PERSISTENT
|
||||
|
||||
return args2dict(subj_id, format=format,
|
||||
sp_name_qualifier=sp_name_qualifier)
|
||||
|
||||
def persistent_nameid(self, sp_name_qualifier, userid):
|
||||
subj_id = self.persistent(sp_name_qualifier, userid)
|
||||
return args2dict(subj_id, format=saml.NAMEID_FORMAT_PERSISTENT,
|
||||
sp_name_qualifier=sp_name_qualifier)
|
||||
|
||||
def construct_nameid(self, local_policy, userid, sp_entity_id,
|
||||
identity=None, name_id_policy=None):
|
||||
if name_id_policy and name_id_policy.sp_name_qualifier:
|
||||
return self._get_vo_identifier(name_id_policy.sp_name_qualifier,
|
||||
userid, identity)
|
||||
else:
|
||||
nameid_format = local_policy.get_nameid_format(sp_entity_id)
|
||||
if nameid_format == saml.NAMEID_FORMAT_PERSISTENT:
|
||||
return self.persistent_nameid(self.entityid, userid)
|
||||
elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT:
|
||||
return self.temporary_nameid()
|
||||
|
||||
def temporary_nameid(self):
|
||||
return args2dict(sid(), format=saml.NAMEID_FORMAT_TRANSIENT)
|
||||
|
||||
|
||||
class Server(object):
|
||||
""" A class that does things that IdPs or AAs do """
|
||||
@@ -99,8 +146,9 @@ class Server(object):
|
||||
self.conf = Config()
|
||||
self.conf.load_file(config_file)
|
||||
if "subject_data" in self.conf:
|
||||
self.id = IdentifierMap(self.conf["subject_data"],
|
||||
self.debug, self.log)
|
||||
self.id = Identifier(self.conf["subject_data"],
|
||||
self.conf["entityid"], self.conf.vo_conf,
|
||||
self.debug, self.log)
|
||||
else:
|
||||
self.id = None
|
||||
|
||||
@@ -190,9 +238,9 @@ class Server(object):
|
||||
|
||||
def find_subject(self, subject, attribute=None):
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def _response(self, consumer_url, in_response_to, sp_entity_id,
|
||||
identity=None, name_id=None, status=None, sign=False,
|
||||
policy=Policy()):
|
||||
@@ -276,57 +324,57 @@ class Server(object):
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def do_aa_response(self, consumer_url, in_response_to, sp_entity_id,
|
||||
identity=None, name_id=None, ip_address="",
|
||||
issuer=None, status=None, sign=False):
|
||||
identity=None, userid="", name_id=None, ip_address="",
|
||||
issuer=None, status=None, sign=False,
|
||||
name_id_policy=None):
|
||||
|
||||
name_id = self.id.construct_nameid(self.conf.aa_policy(), userid,
|
||||
sp_entity_id, identity)
|
||||
|
||||
return self._response(consumer_url, in_response_to,
|
||||
sp_entity_id, identity, name_id,
|
||||
status, sign, policy=self.conf.aa_policy())
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def authn_response(self, identity, in_response_to, destination, spid,
|
||||
name_id_policy, userid):
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def authn_response(self, identity, in_response_to, destination,
|
||||
sp_entity_id, name_id_policy, userid):
|
||||
""" Constructs an AuthenticationResponse
|
||||
|
||||
:param identity: Information about an user
|
||||
:param in_response_to: The identifier of the authentication request
|
||||
this response is an answer to.
|
||||
:param destination: Where the response should be sent
|
||||
:param sid: The entity identifier of the Service Provider
|
||||
:param sp_entity_id: The entity identifier of the Service Provider
|
||||
:param name_id_policy: ...
|
||||
:param userid: The subject identifier
|
||||
:return: A XML string representing an authentication response
|
||||
"""
|
||||
|
||||
name_id = None
|
||||
if name_id_policy.sp_name_qualifier:
|
||||
try:
|
||||
vo_conf = self.conf.vo_conf(name_id_policy.sp_name_qualifier)
|
||||
subj_id = identity[vo_conf["common_identifier"]]
|
||||
except KeyError:
|
||||
self.log.info(
|
||||
"Get persistent ID (%s,%s)" % (
|
||||
name_id_policy.sp_name_qualifier,userid))
|
||||
subj_id = self.id.persistent(name_id_policy.sp_name_qualifier,
|
||||
userid)
|
||||
self.log.info("=> %s" % subj_id)
|
||||
|
||||
name_id = args2dict(subj_id,
|
||||
format=saml.NAMEID_FORMAT_PERSISTENT,
|
||||
sp_name_qualifier=name_id_policy.sp_name_qualifier)
|
||||
try:
|
||||
name_id = self.id.construct_nameid(self.conf.idp_policy(),
|
||||
userid, sp_entity_id, identity,
|
||||
name_id_policy)
|
||||
except IOError, exc:
|
||||
response = self.error_response(destination, in_response_to,
|
||||
sp_entity_id, exc, name_id)
|
||||
return ("%s" % response).split("\n")
|
||||
|
||||
try:
|
||||
resp = self.do_response(
|
||||
response = self.do_response(
|
||||
destination, # consumer_url
|
||||
in_response_to, # in_response_to
|
||||
spid, # sp_entity_id
|
||||
sp_entity_id, # sp_entity_id
|
||||
identity, # identity as dictionary
|
||||
name_id,
|
||||
userid
|
||||
)
|
||||
except MissingValue, exc:
|
||||
resp = self.error_response( destination, in_response_to, spid,
|
||||
exc, name_id)
|
||||
response = self.error_response(destination, in_response_to,
|
||||
sp_entity_id, exc, name_id)
|
||||
|
||||
|
||||
return ("%s" % resp).split("\n")
|
||||
return ("%s" % response).split("\n")
|
||||
Reference in New Issue
Block a user