Refactored and rewrote the nameID DB handling.

This commit is contained in:
Roland Hedberg 2013-01-15 15:47:27 +01:00
parent 3dff9e5245
commit 402a49ad0c
5 changed files with 271 additions and 170 deletions

View File

@ -86,7 +86,9 @@ AA_IDP_ARGS = ["want_authn_requests_signed",
"endpoints", "endpoints",
"metadata", "metadata",
"ui_info", "ui_info",
"name_id_format" "name_id_format",
"domain",
"name_qualifier"
] ]
PDP_ARGS = ["endpoints", "name_form", "name_id_format"] PDP_ARGS = ["endpoints", "name_form", "name_id_format"]
@ -177,6 +179,8 @@ class Config(object):
self.serves = [] self.serves = []
self.vorg = {} self.vorg = {}
self.preferred_binding = PREFERRED_BINDING self.preferred_binding = PREFERRED_BINDING
self.domain = ""
self.name_qualifier = ""
def setattr(self, context, attr, val): def setattr(self, context, attr, val):
if context == "": if context == "":

View File

@ -119,22 +119,22 @@ class Entity(HTTPBase):
return Issuer(text=self.config.entityid, return Issuer(text=self.config.entityid,
format=NAMEID_FORMAT_ENTITY) format=NAMEID_FORMAT_ENTITY)
def apply_binding(self, binding, req_str, destination, relay_state, def apply_binding(self, binding, msg_str, destination, relay_state,
typ="SAMLRequest"): typ="SAMLRequest"):
if binding == BINDING_HTTP_POST: if binding == BINDING_HTTP_POST:
logger.info("HTTP POST") logger.info("HTTP POST")
info = self.use_http_form_post(req_str, destination, info = self.use_http_form_post(msg_str, destination,
relay_state, typ) relay_state, typ)
info["url"] = destination info["url"] = destination
info["method"] = "GET" info["method"] = "GET"
elif binding == BINDING_HTTP_REDIRECT: elif binding == BINDING_HTTP_REDIRECT:
logger.info("HTTP REDIRECT") logger.info("HTTP REDIRECT")
info = self.use_http_get(req_str, destination, relay_state, typ) info = self.use_http_get(msg_str, destination, relay_state, typ)
info["url"] = destination info["url"] = destination
info["method"] = "GET" info["method"] = "GET"
elif binding == BINDING_SOAP: elif binding == BINDING_SOAP:
info = self.use_soap(req_str, destination) info = self.use_soap(msg_str, destination)
else: else:
raise Exception("Unknown binding type: %s" % binding) raise Exception("Unknown binding type: %s" % binding)
@ -433,13 +433,15 @@ class Entity(HTTPBase):
consent, extensions, sign, name_id=name_id, consent, extensions, sign, name_id=name_id,
reason=reason, not_on_or_after=expire) reason=reason, not_on_or_after=expire)
def create_logout_response(self, request, bindings, status=None, def create_logout_response(self, request, bindings=None, status=None,
sign=False, issuer=None): sign=False, issuer=None):
""" Create a LogoutResponse. """ Create a LogoutResponse.
:param request: The request this is a response to :param request: The request this is a response to
:param bindings: Which bindings that can be used for the response :param bindings: Which bindings that can be used for the response
If None the preferred bindings are gathered from the configuration
:param status: The return status of the response operation :param status: The return status of the response operation
If None the operation is regarded as a Success.
:param issuer: The issuer of the message :param issuer: The issuer of the message
:return: HTTP args :return: HTTP args
""" """

232
src/saml2/ident.py Normal file
View File

@ -0,0 +1,232 @@
import shelve
from hashlib import sha256
from urllib import quote
from urllib import unquote
from saml2.s_utils import rndstr
from saml2.saml import NameID
from saml2.saml import NAMEID_FORMAT_TRANSIENT
from saml2.saml import NAMEID_FORMAT_EMAILADDRESS
__author__ = 'rolandh'
ATTR = ["name_qualifier", "sp_name_qualifier", "format", "sp_provided_id"]
class Unknown(Exception):
pass
def code(item):
_res = []
i = 0
for attr in ATTR:
val = getattr(item, attr)
if val:
_res.append("%d=%s" % (i, quote(val)))
i += 1
return ",".join(_res)
def decode(str):
_nid = NameID()
for part in str.split(","):
i, val = part.split("=")
setattr(_nid, ATTR[int(i)], unquote(val))
return _nid
class IdentDB(object):
""" A class that handles identifiers of entities
Keeps a list of all nameIDs returned per SP
"""
def __init__(self, db, domain="", name_qualifier=""):
if isinstance(db, basestring):
self.db = shelve.open(db)
else:
self.db = db
self.domain = domain
self.name_qualifier = name_qualifier
def _create_id(self, format, name_qualifier="", sp_name_qualifier=""):
_id = sha256(rndstr(32))
_id.update(format)
if name_qualifier:
_id.update(name_qualifier)
if sp_name_qualifier:
_id.update(sp_name_qualifier)
return _id.hexdigest()
def create_id(self, format, name_qualifier="", sp_name_qualifier=""):
_id = self._create_id(format, name_qualifier, sp_name_qualifier)
while _id in self.db:
_id = self._create_id(format, name_qualifier, sp_name_qualifier)
return _id
def store(self, id, name_id):
try:
val = self.db[id].split(" ")
except KeyError:
val = []
_cn = code(name_id)
val.append(_cn)
self.db[id] = " ".join(val)
self.db[_cn] = id
def remove_remote(self, name_id):
_cn = code(name_id)
_id = self.db[_cn]
try:
vals = self.db[_id].split(" ")
vals.remove(_cn)
self.db[id] = " ".join(vals)
except KeyError:
pass
del self.db[_cn]
def remove_local(self, id):
try:
for val in self.db[id].split(" "):
try:
del self.db[val]
except KeyError:
pass
del self.db[id]
except KeyError:
pass
def get_nameid(self, format, sp_name_qualifier, userid, name_qualifier):
_id = self.create_id(format, name_qualifier, sp_name_qualifier)
if format == NAMEID_FORMAT_EMAILADDRESS:
if not self.domain:
raise Exception("Can't issue email nameids, unknown domain")
_id = "%s@%s" % (_id, self.domain)
nameid = NameID(format=format, sp_name_qualifier=sp_name_qualifier,
name_qualifier=name_qualifier, text=_id)
self.store(userid, nameid)
return nameid
def nim_args(self, local_policy=None, sp_name_qualifier="",
name_id_policy=None, name_qualifier=""):
"""
:param local_policy:
:param sp_name_qualifier:
:param name_id_policy:
:param name_qualifier:
:return:
"""
if name_id_policy and name_id_policy.sp_name_qualifier:
sp_name_qualifier = name_id_policy.sp_name_qualifier
else:
sp_name_qualifier = sp_name_qualifier
if name_id_policy:
nameid_format = name_id_policy.format
elif local_policy:
nameid_format = local_policy.get_nameid_format(sp_name_qualifier)
else:
raise Exception("Unknown NameID format")
if not name_qualifier:
name_qualifier = self.name_qualifier
return {"format":nameid_format, "sp_name_qualifier": sp_name_qualifier,
"name_qualifier":name_qualifier}
def construct_nameid(self, userid, local_policy=None,
sp_name_qualifier=None, name_id_policy=None,
sp_nid=None, name_qualifier=""):
""" Returns a name_id for the object. How the name_id is
constructed depends on the context.
:param local_policy: The policy the server is configured to follow
:param userid: The local permanent identifier of the object
:param sp_name_qualifier: The 'user'/-s of the name_id
:param name_id_policy: The policy the server on the other side wants
us to follow.
:param sp_nid: Name ID Formats from the SPs metadata
:return: NameID instance precursor
"""
args = self.nim_args(local_policy, sp_name_qualifier, name_id_policy)
return self.get_nameid(userid, **args)
def find_local_id(self, name_id):
"""
Only find on persistent IDs
:param name_id:
:return:
"""
try:
return self.db[code(name_id)]
except KeyError:
return None
def match_local_id(self, userid, sp_name_qualifier, name_qualifier):
try:
for val in self.db[userid].split(" "):
nid = decode(val)
if nid.format == NAMEID_FORMAT_TRANSIENT:
continue
if getattr(nid, "sp_name_qualifier", "") == sp_name_qualifier:
if getattr(nid, "name_qualifier", "") == name_qualifier:
return nid
except KeyError:
pass
return None
def handle_name_id_mapping_request(self, name_id, name_id_policy):
"""
:param name_id: The NameID that specifies the principal
:param name_id_policy: The NameIDPolicy of the requester
:return: If an old name_id exists that match the name-id policy
that is return otherwise if a new one can be created it
will be and returned. If no old matching exists and a new
is not allowed to be created None is returned.
"""
_id = self.find_local_id(name_id)
if not _id:
raise Unknown("Unknown entity")
# return an old one if present
for val in self.db[_id].split(" "):
_nid = decode(val)
if _nid.format == name_id_policy.format:
if _nid.sp_name_qualifier == name_id_policy.sp_name_qualifier:
return _nid
if name_id_policy.allow_create == "false":
return None
# else create and return a new one
return self.construct_nameid(_id, name_id_policy=name_id_policy)
def publish(self, userid, name_id, entity_id):
"""
About userid I have published nameid to entity_id
Will gladly overwrite whatever was there before
:param userid:
:param name_id:
:param entity_id:
:return:
"""
self.db["%s:%s" % (userid, entity_id)] = name_id
def published(self, userid, entity_id):
"""
:param userid:
:param entity_id:
:return:
"""
try:
return self.db["%s:%s" % (userid, entity_id)]
except KeyError:
return None

View File

@ -23,6 +23,8 @@ import logging
import shelve import shelve
import sys import sys
import memcache import memcache
from hashlib import sha1
from saml2.samlp import NameIDMappingResponse from saml2.samlp import NameIDMappingResponse
from saml2.entity import Entity from saml2.entity import Entity
@ -37,7 +39,6 @@ from saml2.request import NameIDMappingRequest
from saml2.request import AuthzDecisionQuery from saml2.request import AuthzDecisionQuery
from saml2.request import AuthnQuery from saml2.request import AuthnQuery
from saml2.s_utils import sid
from saml2.s_utils import MissingValue from saml2.s_utils import MissingValue
from saml2.s_utils import error_status_factory from saml2.s_utils import error_status_factory
@ -48,171 +49,16 @@ from saml2.assertion import Policy
from saml2.assertion import restriction_from_attribute_spec from saml2.assertion import restriction_from_attribute_spec
from saml2.assertion import filter_attribute_value_assertions from saml2.assertion import filter_attribute_value_assertions
from saml2.ident import IdentDB
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UnknownVO(Exception):
pass
def context_match(cfilter, cntx): def context_match(cfilter, cntx):
# TODO
return True return True
class Identifier(object):
""" A class that handles identifiers of objects """
def __init__(self, db, voconf=None):
if isinstance(db, basestring):
self.map = shelve.open(db, writeback=True)
else:
self.map = db
self.voconf = voconf
def _store(self, typ, entity_id, local, remote):
self.map["|".join([typ, entity_id, "f", local])] = remote
self.map["|".join([typ, entity_id, "b", remote])] = local
def _get_remote(self, typ, entity_id, local):
return self.map["|".join([typ, entity_id, "f", local])]
def _get_local(self, typ, entity_id, remote):
return self.map["|".join([typ, entity_id, "b", remote])]
def persistent(self, entity_id, subject_id):
""" Keeps the link between a permanent identifier and a
temporary/pseudo-temporary identifier for a subject
The store supports look-up both ways: from a permanent local
identifier to a identifier used talking to a SP and from an
identifier given back by an SP to the local permanent.
:param entity_id: SP entity ID or VO entity ID
:param subject_id: The local permanent identifier of the subject
:return: An arbitrary identifier for the subject unique to the
service/group of services/VO with a given entity_id
"""
try:
return self._get_remote("persistent", entity_id, subject_id)
except KeyError:
temp_id = "xyz"
while True:
temp_id = sid()
try:
self._get_local("persistent", entity_id, temp_id)
except KeyError:
break
self._store("persistent", entity_id, subject_id, temp_id)
self.map.sync()
return temp_id
def _get_vo_identifier(self, sp_name_qualifier, identity):
try:
vo = self.voconf[sp_name_qualifier]
try:
subj_id = identity[vo.common_identifier]
except KeyError:
raise MissingValue("Common identifier")
except (KeyError, TypeError):
raise UnknownVO("%s" % sp_name_qualifier)
nameid_format = vo.nameid_format
if not nameid_format:
nameid_format = saml.NAMEID_FORMAT_PERSISTENT
return saml.NameID(format=nameid_format,
sp_name_qualifier=sp_name_qualifier,
text=subj_id)
def persistent_nameid(self, sp_name_qualifier, userid):
""" Get or create a persistent identifier for this object to be used
when communicating with servers using a specific SPNameQualifier
:param sp_name_qualifier: An identifier for a 'context'
:param userid: The local permanent identifier of the object
:return: A persistent random identifier.
"""
subj_id = self.persistent(sp_name_qualifier, userid)
return saml.NameID(format=saml.NAMEID_FORMAT_PERSISTENT,
sp_name_qualifier=sp_name_qualifier,
text=subj_id)
def transient_nameid(self, sp_entity_id, userid):
""" Returns a random one-time identifier. One-time means it is
kept around as long as the session is active.
:param sp_entity_id: A qualifier to bind the created identifier to
:param userid: The local persistent identifier for the subject.
:return: The created identifier,
"""
temp_id = sid()
while True:
try:
_ = self._get_local("transient", sp_entity_id, temp_id)
temp_id = sid()
except KeyError:
break
self._store("transient", sp_entity_id, userid, temp_id)
self.map.sync()
return saml.NameID(format=saml.NAMEID_FORMAT_TRANSIENT,
sp_name_qualifier=sp_entity_id,
text=temp_id)
def email_nameid(self, sp_name_qualifier, userid):
return saml.NameID(format=saml.NAMEID_FORMAT_EMAILADDRESS,
sp_name_qualifier=sp_name_qualifier,
text=userid)
def construct_nameid(self, local_policy, userid, sp_entity_id,
identity=None, name_id_policy=None, sp_nid=None):
""" Returns a name_id for the object. How the name_id is
constructed depends on the context.
:param local_policy: The policy the server is configured to follow
:param userid: The local permanent identifier of the object
:param sp_entity_id: The 'user' of the name_id
:param identity: Attribute/value pairs describing the object
:param name_id_policy: The policy the server on the other side wants
us to follow.
:param sp_nid: Name ID Formats from the SPs metadata
:return: NameID instance precursor
"""
if name_id_policy and name_id_policy.sp_name_qualifier:
try:
return self._get_vo_identifier(name_id_policy.sp_name_qualifier,
identity)
except Exception, exc:
print >> sys.stderr, "%s:%s" % (exc.__class__.__name__, exc)
if name_id_policy:
nameid_format = name_id_policy.format
elif sp_nid:
nameid_format = sp_nid[0]
elif local_policy:
nameid_format = local_policy.get_nameid_format(sp_entity_id)
else:
raise Exception("Unknown NameID format")
if nameid_format == saml.NAMEID_FORMAT_PERSISTENT:
return self.persistent_nameid(sp_entity_id, userid)
elif nameid_format == saml.NAMEID_FORMAT_TRANSIENT:
return self.transient_nameid(sp_entity_id, userid)
elif nameid_format == saml.NAMEID_FORMAT_EMAILADDRESS:
return self.email_nameid(sp_entity_id, userid)
def local_name(self, entity_id, remote_id):
""" Get the local persistent name that has the specified remote ID.
:param entity_id: The identifier of the entity that got the remote id
:param remote_id: The identifier that was exported
:return: Local identifier
"""
try:
return self._get_local("persistent", entity_id, remote_id)
except KeyError:
try:
return self._get_local("transient", entity_id, remote_id)
except KeyError:
return None
class Server(Entity): class Server(Entity):
""" A class that does things that IdPs or AAs do """ """ A class that does things that IdPs or AAs do """
def __init__(self, config_file="", config=None, _cache="", stype="idp"): def __init__(self, config_file="", config=None, _cache="", stype="idp"):
@ -249,7 +95,7 @@ class Server(Entity):
idb = addr idb = addr
if idb is not None: if idb is not None:
self.ident = Identifier(idb, self.config.virtual_organization) self.ident = IdentDB(idb)
else: else:
raise Exception("Couldn't open identity database: %s" % raise Exception("Couldn't open identity database: %s" %
(dbspec,)) (dbspec,))
@ -349,16 +195,31 @@ class Server(Entity):
def get_assertion(self, id): def get_assertion(self, id):
return self.assertion[id] return self.assertion[id]
def store_authn_statement(self, authn_statement, name_id): def store_authn_statement(self, authn_statement, subject):
"""
:param authn_statement:
:param subject:
:return:
"""
key = sha1("%s" % subject).digest()
try: try:
self.authn[name_id.text].append(authn_statement) self.authn[key].append(authn_statement)
except: except:
self.authn[name_id.text] = [authn_statement] self.authn[key] = [authn_statement]
def get_authn_statements(self, subject, session_index=None, def get_authn_statements(self, subject, session_index=None,
requested_context=None): requested_context=None):
"""
:param subject:
:param session_index:
:param requested_context:
:return:
"""
result = [] result = []
for statement in self.authn[subject.name_id.text]: key = sha1("%s" % subject).digest()
for statement in self.authn[key]:
if session_index: if session_index:
if statement.session_index != session_index: if statement.session_index != session_index:
continue continue

View File

@ -38,6 +38,8 @@ CONFIG = {
} }
}, },
"subject_data": "subject_data.db", "subject_data": "subject_data.db",
#"domain": "umu.se",
#"name_qualifier": ""
}, },
}, },
"debug" : 1, "debug" : 1,