Refactored and rewrote the nameID DB handling.
This commit is contained in:
parent
3dff9e5245
commit
402a49ad0c
@ -86,7 +86,9 @@ AA_IDP_ARGS = ["want_authn_requests_signed",
|
||||
"endpoints",
|
||||
"metadata",
|
||||
"ui_info",
|
||||
"name_id_format"
|
||||
"name_id_format",
|
||||
"domain",
|
||||
"name_qualifier"
|
||||
]
|
||||
|
||||
PDP_ARGS = ["endpoints", "name_form", "name_id_format"]
|
||||
@ -177,6 +179,8 @@ class Config(object):
|
||||
self.serves = []
|
||||
self.vorg = {}
|
||||
self.preferred_binding = PREFERRED_BINDING
|
||||
self.domain = ""
|
||||
self.name_qualifier = ""
|
||||
|
||||
def setattr(self, context, attr, val):
|
||||
if context == "":
|
||||
|
@ -119,22 +119,22 @@ class Entity(HTTPBase):
|
||||
return Issuer(text=self.config.entityid,
|
||||
format=NAMEID_FORMAT_ENTITY)
|
||||
|
||||
def apply_binding(self, binding, req_str, destination, relay_state,
|
||||
def apply_binding(self, binding, msg_str, destination, relay_state,
|
||||
typ="SAMLRequest"):
|
||||
|
||||
if binding == BINDING_HTTP_POST:
|
||||
logger.info("HTTP POST")
|
||||
info = self.use_http_form_post(req_str, destination,
|
||||
info = self.use_http_form_post(msg_str, destination,
|
||||
relay_state, typ)
|
||||
info["url"] = destination
|
||||
info["method"] = "GET"
|
||||
elif binding == BINDING_HTTP_REDIRECT:
|
||||
logger.info("HTTP REDIRECT")
|
||||
info = self.use_http_get(req_str, destination, relay_state, typ)
|
||||
info = self.use_http_get(msg_str, destination, relay_state, typ)
|
||||
info["url"] = destination
|
||||
info["method"] = "GET"
|
||||
elif binding == BINDING_SOAP:
|
||||
info = self.use_soap(req_str, destination)
|
||||
info = self.use_soap(msg_str, destination)
|
||||
else:
|
||||
raise Exception("Unknown binding type: %s" % binding)
|
||||
|
||||
@ -433,13 +433,15 @@ class Entity(HTTPBase):
|
||||
consent, extensions, sign, name_id=name_id,
|
||||
reason=reason, not_on_or_after=expire)
|
||||
|
||||
def create_logout_response(self, request, bindings, status=None,
|
||||
def create_logout_response(self, request, bindings=None, status=None,
|
||||
sign=False, issuer=None):
|
||||
""" Create a LogoutResponse.
|
||||
|
||||
:param request: The request this is a response to
|
||||
:param bindings: Which bindings that can be used for the response
|
||||
If None the preferred bindings are gathered from the configuration
|
||||
:param status: The return status of the response operation
|
||||
If None the operation is regarded as a Success.
|
||||
:param issuer: The issuer of the message
|
||||
:return: HTTP args
|
||||
"""
|
||||
|
232
src/saml2/ident.py
Normal file
232
src/saml2/ident.py
Normal file
@ -0,0 +1,232 @@
|
||||
import shelve
|
||||
from hashlib import sha256
|
||||
from urllib import quote
|
||||
from urllib import unquote
|
||||
from saml2.s_utils import rndstr
|
||||
from saml2.saml import NameID
|
||||
from saml2.saml import NAMEID_FORMAT_TRANSIENT
|
||||
from saml2.saml import NAMEID_FORMAT_EMAILADDRESS
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id"]
|
||||
|
||||
class Unknown(Exception):
|
||||
pass
|
||||
|
||||
def code(item):
|
||||
_res = []
|
||||
i = 0
|
||||
for attr in ATTR:
|
||||
val = getattr(item, attr)
|
||||
if val:
|
||||
_res.append("%d=%s" % (i, quote(val)))
|
||||
i += 1
|
||||
return ",".join(_res)
|
||||
|
||||
def decode(str):
|
||||
_nid = NameID()
|
||||
for part in str.split(","):
|
||||
i, val = part.split("=")
|
||||
setattr(_nid, ATTR[int(i)], unquote(val))
|
||||
return _nid
|
||||
|
||||
class IdentDB(object):
|
||||
""" A class that handles identifiers of entities
|
||||
Keeps a list of all nameIDs returned per SP
|
||||
"""
|
||||
def __init__(self, db, domain="", name_qualifier=""):
|
||||
if isinstance(db, basestring):
|
||||
self.db = shelve.open(db)
|
||||
else:
|
||||
self.db = db
|
||||
self.domain = domain
|
||||
self.name_qualifier = name_qualifier
|
||||
|
||||
def _create_id(self, format, name_qualifier="", sp_name_qualifier=""):
|
||||
_id = sha256(rndstr(32))
|
||||
_id.update(format)
|
||||
if name_qualifier:
|
||||
_id.update(name_qualifier)
|
||||
if sp_name_qualifier:
|
||||
_id.update(sp_name_qualifier)
|
||||
return _id.hexdigest()
|
||||
|
||||
def create_id(self, format, name_qualifier="", sp_name_qualifier=""):
|
||||
_id = self._create_id(format, name_qualifier, sp_name_qualifier)
|
||||
while _id in self.db:
|
||||
_id = self._create_id(format, name_qualifier, sp_name_qualifier)
|
||||
return _id
|
||||
|
||||
def store(self, id, name_id):
|
||||
try:
|
||||
val = self.db[id].split(" ")
|
||||
except KeyError:
|
||||
val = []
|
||||
|
||||
_cn = code(name_id)
|
||||
val.append(_cn)
|
||||
self.db[id] = " ".join(val)
|
||||
self.db[_cn] = id
|
||||
|
||||
def remove_remote(self, name_id):
|
||||
_cn = code(name_id)
|
||||
_id = self.db[_cn]
|
||||
try:
|
||||
vals = self.db[_id].split(" ")
|
||||
vals.remove(_cn)
|
||||
self.db[id] = " ".join(vals)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
del self.db[_cn]
|
||||
|
||||
def remove_local(self, id):
|
||||
try:
|
||||
for val in self.db[id].split(" "):
|
||||
try:
|
||||
del self.db[val]
|
||||
except KeyError:
|
||||
pass
|
||||
del self.db[id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def get_nameid(self, format, sp_name_qualifier, userid, name_qualifier):
|
||||
_id = self.create_id(format, name_qualifier, sp_name_qualifier)
|
||||
|
||||
if format == NAMEID_FORMAT_EMAILADDRESS:
|
||||
if not self.domain:
|
||||
raise Exception("Can't issue email nameids, unknown domain")
|
||||
|
||||
_id = "%s@%s" % (_id, self.domain)
|
||||
|
||||
nameid = NameID(format=format, sp_name_qualifier=sp_name_qualifier,
|
||||
name_qualifier=name_qualifier, text=_id)
|
||||
|
||||
self.store(userid, nameid)
|
||||
return nameid
|
||||
|
||||
def nim_args(self, local_policy=None, sp_name_qualifier="",
|
||||
name_id_policy=None, name_qualifier=""):
|
||||
"""
|
||||
|
||||
:param local_policy:
|
||||
:param sp_name_qualifier:
|
||||
:param name_id_policy:
|
||||
:param name_qualifier:
|
||||
:return:
|
||||
"""
|
||||
if name_id_policy and name_id_policy.sp_name_qualifier:
|
||||
sp_name_qualifier = name_id_policy.sp_name_qualifier
|
||||
else:
|
||||
sp_name_qualifier = sp_name_qualifier
|
||||
|
||||
if name_id_policy:
|
||||
nameid_format = name_id_policy.format
|
||||
elif local_policy:
|
||||
nameid_format = local_policy.get_nameid_format(sp_name_qualifier)
|
||||
else:
|
||||
raise Exception("Unknown NameID format")
|
||||
|
||||
if not name_qualifier:
|
||||
name_qualifier = self.name_qualifier
|
||||
|
||||
return {"format":nameid_format, "sp_name_qualifier": sp_name_qualifier,
|
||||
"name_qualifier":name_qualifier}
|
||||
|
||||
def construct_nameid(self, userid, local_policy=None,
|
||||
sp_name_qualifier=None, name_id_policy=None,
|
||||
sp_nid=None, name_qualifier=""):
|
||||
""" Returns a name_id for the object. How the name_id is
|
||||
constructed depends on the context.
|
||||
|
||||
:param local_policy: The policy the server is configured to follow
|
||||
:param userid: The local permanent identifier of the object
|
||||
:param sp_name_qualifier: The 'user'/-s of the name_id
|
||||
:param name_id_policy: The policy the server on the other side wants
|
||||
us to follow.
|
||||
:param sp_nid: Name ID Formats from the SPs metadata
|
||||
:return: NameID instance precursor
|
||||
"""
|
||||
|
||||
args = self.nim_args(local_policy, sp_name_qualifier, name_id_policy)
|
||||
return self.get_nameid(userid, **args)
|
||||
|
||||
def find_local_id(self, name_id):
|
||||
"""
|
||||
Only find on persistent IDs
|
||||
|
||||
:param name_id:
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.db[code(name_id)]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def match_local_id(self, userid, sp_name_qualifier, name_qualifier):
|
||||
try:
|
||||
for val in self.db[userid].split(" "):
|
||||
nid = decode(val)
|
||||
if nid.format == NAMEID_FORMAT_TRANSIENT:
|
||||
continue
|
||||
if getattr(nid, "sp_name_qualifier", "") == sp_name_qualifier:
|
||||
if getattr(nid, "name_qualifier", "") == name_qualifier:
|
||||
return nid
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def handle_name_id_mapping_request(self, name_id, name_id_policy):
|
||||
"""
|
||||
|
||||
:param name_id: The NameID that specifies the principal
|
||||
:param name_id_policy: The NameIDPolicy of the requester
|
||||
:return: If an old name_id exists that match the name-id policy
|
||||
that is return otherwise if a new one can be created it
|
||||
will be and returned. If no old matching exists and a new
|
||||
is not allowed to be created None is returned.
|
||||
"""
|
||||
_id = self.find_local_id(name_id)
|
||||
if not _id:
|
||||
raise Unknown("Unknown entity")
|
||||
|
||||
# return an old one if present
|
||||
for val in self.db[_id].split(" "):
|
||||
_nid = decode(val)
|
||||
if _nid.format == name_id_policy.format:
|
||||
if _nid.sp_name_qualifier == name_id_policy.sp_name_qualifier:
|
||||
return _nid
|
||||
|
||||
if name_id_policy.allow_create == "false":
|
||||
return None
|
||||
|
||||
# else create and return a new one
|
||||
return self.construct_nameid(_id, name_id_policy=name_id_policy)
|
||||
|
||||
def publish(self, userid, name_id, entity_id):
|
||||
"""
|
||||
About userid I have published nameid to entity_id
|
||||
Will gladly overwrite whatever was there before
|
||||
:param userid:
|
||||
:param name_id:
|
||||
:param entity_id:
|
||||
:return:
|
||||
"""
|
||||
|
||||
self.db["%s:%s" % (userid, entity_id)] = name_id
|
||||
|
||||
def published(self, userid, entity_id):
|
||||
"""
|
||||
|
||||
:param userid:
|
||||
:param entity_id:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
return self.db["%s:%s" % (userid, entity_id)]
|
||||
except KeyError:
|
||||
return None
|
@ -23,6 +23,8 @@ import logging
|
||||
import shelve
|
||||
import sys
|
||||
import memcache
|
||||
from hashlib import sha1
|
||||
|
||||
from saml2.samlp import NameIDMappingResponse
|
||||
from saml2.entity import Entity
|
||||
|
||||
@ -37,7 +39,6 @@ from saml2.request import NameIDMappingRequest
|
||||
from saml2.request import AuthzDecisionQuery
|
||||
from saml2.request import AuthnQuery
|
||||
|
||||
from saml2.s_utils import sid
|
||||
from saml2.s_utils import MissingValue
|
||||
from saml2.s_utils import error_status_factory
|
||||
|
||||
@ -48,171 +49,16 @@ from saml2.assertion import Policy
|
||||
from saml2.assertion import restriction_from_attribute_spec
|
||||
from saml2.assertion import filter_attribute_value_assertions
|
||||
|
||||
from saml2.ident import IdentDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UnknownVO(Exception):
|
||||
pass
|
||||
|
||||
def context_match(cfilter, cntx):
|
||||
# TODO
|
||||
return True
|
||||
|
||||
class Identifier(object):
|
||||
""" A class that handles identifiers of objects """
|
||||
def __init__(self, db, voconf=None):
|
||||
if isinstance(db, basestring):
|
||||
self.map = shelve.open(db, writeback=True)
|
||||
else:
|
||||
self.map = db
|
||||
self.voconf = voconf
|
||||
|
||||
def _store(self, typ, entity_id, local, remote):
|
||||
self.map["|".join([typ, entity_id, "f", local])] = remote
|
||||
self.map["|".join([typ, entity_id, "b", remote])] = local
|
||||
|
||||
def _get_remote(self, typ, entity_id, local):
|
||||
return self.map["|".join([typ, entity_id, "f", local])]
|
||||
|
||||
def _get_local(self, typ, entity_id, remote):
|
||||
return self.map["|".join([typ, entity_id, "b", remote])]
|
||||
|
||||
def persistent(self, entity_id, subject_id):
|
||||
""" Keeps the link between a permanent identifier and a
|
||||
temporary/pseudo-temporary identifier for a subject
|
||||
|
||||
The store supports look-up both ways: from a permanent local
|
||||
identifier to a identifier used talking to a SP and from an
|
||||
identifier given back by an SP to the local permanent.
|
||||
|
||||
:param entity_id: SP entity ID or VO entity ID
|
||||
:param subject_id: The local permanent identifier of the subject
|
||||
:return: An arbitrary identifier for the subject unique to the
|
||||
service/group of services/VO with a given entity_id
|
||||
"""
|
||||
try:
|
||||
return self._get_remote("persistent", entity_id, subject_id)
|
||||
except KeyError:
|
||||
temp_id = "xyz"
|
||||
while True:
|
||||
temp_id = sid()
|
||||
try:
|
||||
self._get_local("persistent", entity_id, temp_id)
|
||||
except KeyError:
|
||||
break
|
||||
self._store("persistent", entity_id, subject_id, temp_id)
|
||||
self.map.sync()
|
||||
|
||||
return temp_id
|
||||
|
||||
def _get_vo_identifier(self, sp_name_qualifier, identity):
|
||||
try:
|
||||
vo = self.voconf[sp_name_qualifier]
|
||||
try:
|
||||
subj_id = identity[vo.common_identifier]
|
||||
except KeyError:
|
||||
raise MissingValue("Common identifier")
|
||||
except (KeyError, TypeError):
|
||||
raise UnknownVO("%s" % sp_name_qualifier)
|
||||
|
||||
nameid_format = vo.nameid_format
|
||||
if not nameid_format:
|
||||
nameid_format = saml.NAMEID_FORMAT_PERSISTENT
|
||||
|
||||
return saml.NameID(format=nameid_format,
|
||||
sp_name_qualifier=sp_name_qualifier,
|
||||
text=subj_id)
|
||||
|
||||
def persistent_nameid(self, sp_name_qualifier, userid):
|
||||
""" Get or create a persistent identifier for this object to be used
|
||||
when communicating with servers using a specific SPNameQualifier
|
||||
|
||||
:param sp_name_qualifier: An identifier for a 'context'
|
||||
:param userid: The local permanent identifier of the object
|
||||
:return: A persistent random identifier.
|
||||
"""
|
||||
subj_id = self.persistent(sp_name_qualifier, userid)
|
||||
return saml.NameID(format=saml.NAMEID_FORMAT_PERSISTENT,
|
||||
sp_name_qualifier=sp_name_qualifier,
|
||||
text=subj_id)
|
||||
|
||||
def transient_nameid(self, sp_entity_id, userid):
|
||||
""" Returns a random one-time identifier. One-time means it is
|
||||
kept around as long as the session is active.
|
||||
|
||||
:param sp_entity_id: A qualifier to bind the created identifier to
|
||||
:param userid: The local persistent identifier for the subject.
|
||||
:return: The created identifier,
|
||||
"""
|
||||
temp_id = sid()
|
||||
while True:
|
||||
try:
|
||||
_ = self._get_local("transient", sp_entity_id, temp_id)
|
||||
temp_id = sid()
|
||||
except KeyError:
|
||||
break
|
||||
self._store("transient", sp_entity_id, userid, temp_id)
|
||||
self.map.sync()
|
||||
|
||||
return saml.NameID(format=saml.NAMEID_FORMAT_TRANSIENT,
|
||||
sp_name_qualifier=sp_entity_id,
|
||||
text=temp_id)
|
||||
|
||||
def email_nameid(self, sp_name_qualifier, userid):
|
||||
return saml.NameID(format=saml.NAMEID_FORMAT_EMAILADDRESS,
|
||||
sp_name_qualifier=sp_name_qualifier,
|
||||
text=userid)
|
||||
|
||||
def construct_nameid(self, local_policy, userid, sp_entity_id,
|
||||
identity=None, name_id_policy=None, sp_nid=None):
|
||||
""" Returns a name_id for the object. How the name_id is
|
||||
constructed depends on the context.
|
||||
|
||||
:param local_policy: The policy the server is configured to follow
|
||||
:param userid: The local permanent identifier of the object
|
||||
:param sp_entity_id: The 'user' of the name_id
|
||||
:param identity: Attribute/value pairs describing the object
|
||||
:param name_id_policy: The policy the server on the other side wants
|
||||
us to follow.
|
||||
:param sp_nid: Name ID Formats from the SPs metadata
|
||||
:return: NameID instance precursor
|
||||
"""
|
||||
if name_id_policy and name_id_policy.sp_name_qualifier:
|
||||
try:
|
||||
return self._get_vo_identifier(name_id_policy.sp_name_qualifier,
|
||||
identity)
|
||||
except Exception, exc:
|
||||
print >> sys.stderr, "%s:%s" % (exc.__class__.__name__, exc)
|
||||
|
||||
if name_id_policy:
|
||||
nameid_format = name_id_policy.format
|
||||
elif sp_nid:
|
||||
nameid_format = sp_nid[0]
|
||||
elif local_policy:
|
||||
nameid_format = local_policy.get_nameid_format(sp_entity_id)
|
||||
else:
|
||||
raise Exception("Unknown NameID format")
|
||||
|
||||
if nameid_format == saml.NAMEID_FORMAT_PERSISTENT:
|
||||
return self.persistent_nameid(sp_entity_id, userid)
|
||||
elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT:
|
||||
return self.transient_nameid(sp_entity_id, userid)
|
||||
elif nameid_format == saml.NAMEID_FORMAT_EMAILADDRESS:
|
||||
return self.email_nameid(sp_entity_id, userid)
|
||||
|
||||
def local_name(self, entity_id, remote_id):
|
||||
""" Get the local persistent name that has the specified remote ID.
|
||||
|
||||
:param entity_id: The identifier of the entity that got the remote id
|
||||
:param remote_id: The identifier that was exported
|
||||
:return: Local identifier
|
||||
"""
|
||||
try:
|
||||
return self._get_local("persistent", entity_id, remote_id)
|
||||
except KeyError:
|
||||
try:
|
||||
return self._get_local("transient", entity_id, remote_id)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
class Server(Entity):
|
||||
""" A class that does things that IdPs or AAs do """
|
||||
def __init__(self, config_file="", config=None, _cache="", stype="idp"):
|
||||
@ -249,7 +95,7 @@ class Server(Entity):
|
||||
idb = addr
|
||||
|
||||
if idb is not None:
|
||||
self.ident = Identifier(idb, self.config.virtual_organization)
|
||||
self.ident = IdentDB(idb)
|
||||
else:
|
||||
raise Exception("Couldn't open identity database: %s" %
|
||||
(dbspec,))
|
||||
@ -349,16 +195,31 @@ class Server(Entity):
|
||||
def get_assertion(self, id):
|
||||
return self.assertion[id]
|
||||
|
||||
def store_authn_statement(self, authn_statement, name_id):
|
||||
def store_authn_statement(self, authn_statement, subject):
|
||||
"""
|
||||
|
||||
:param authn_statement:
|
||||
:param subject:
|
||||
:return:
|
||||
"""
|
||||
key = sha1("%s" % subject).digest()
|
||||
try:
|
||||
self.authn[name_id.text].append(authn_statement)
|
||||
self.authn[key].append(authn_statement)
|
||||
except:
|
||||
self.authn[name_id.text] = [authn_statement]
|
||||
self.authn[key] = [authn_statement]
|
||||
|
||||
def get_authn_statements(self, subject, session_index=None,
|
||||
requested_context=None):
|
||||
"""
|
||||
|
||||
:param subject:
|
||||
:param session_index:
|
||||
:param requested_context:
|
||||
:return:
|
||||
"""
|
||||
result = []
|
||||
for statement in self.authn[subject.name_id.text]:
|
||||
key = sha1("%s" % subject).digest()
|
||||
for statement in self.authn[key]:
|
||||
if session_index:
|
||||
if statement.session_index != session_index:
|
||||
continue
|
||||
|
@ -38,6 +38,8 @@ CONFIG = {
|
||||
}
|
||||
},
|
||||
"subject_data": "subject_data.db",
|
||||
#"domain": "umu.se",
|
||||
#"name_qualifier": ""
|
||||
},
|
||||
},
|
||||
"debug" : 1,
|
||||
|
Loading…
Reference in New Issue
Block a user