Merge branch 'master' of https://github.com/rohe/pysaml2
This commit is contained in:
@@ -31,6 +31,7 @@ from saml2.saml import AUTHN_PASSWORD
|
||||
|
||||
logger = logging.getLogger("saml2.idp")
|
||||
|
||||
|
||||
def _expiration(timeout, tformat="%a, %d-%b-%Y %H:%M:%S GMT"):
|
||||
"""
|
||||
|
||||
@@ -50,7 +51,7 @@ def _expiration(timeout, tformat="%a, %d-%b-%Y %H:%M:%S GMT"):
|
||||
|
||||
|
||||
def dict2list_of_tuples(d):
|
||||
return [(k,v) for k,v in d.items()]
|
||||
return [(k, v) for k, v in d.items()]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -65,7 +66,7 @@ class Service(object):
|
||||
def unpack_redirect(self):
|
||||
if "QUERY_STRING" in self.environ:
|
||||
_qs = self.environ["QUERY_STRING"]
|
||||
return dict([(k,v[0]) for k,v in parse_qs(_qs).items()])
|
||||
return dict([(k, v[0]) for k, v in parse_qs(_qs).items()])
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -270,8 +271,13 @@ class SSO(Service):
|
||||
IDP.ticket[key] = _dict
|
||||
_resp = key
|
||||
else:
|
||||
_resp = IDP.ticket[_dict["key"]]
|
||||
del IDP.ticket[_dict["key"]]
|
||||
try:
|
||||
_resp = IDP.ticket[_dict["key"]]
|
||||
del IDP.ticket[_dict["key"]]
|
||||
except KeyError:
|
||||
key = sha1("%s" % _dict).hexdigest()
|
||||
IDP.ticket[key] = _dict
|
||||
_resp = key
|
||||
|
||||
return _resp
|
||||
|
||||
@@ -391,8 +397,9 @@ def do_verify(environ, start_response, _):
|
||||
if not _ok:
|
||||
resp = Unauthorized("Unknown user or wrong password")
|
||||
else:
|
||||
uid = rndstr()
|
||||
IDP.authn[uid] = user
|
||||
uid = rndstr(24)
|
||||
IDP.uid2user[uid] = user
|
||||
IDP.user2uid[user] = uid
|
||||
logger.debug("Register %s under '%s'" % (user, uid))
|
||||
kaka = set_cookie("idpauthn", "/", uid)
|
||||
lox = "http://%s%s?id=%s&key=%s" % (environ["HTTP_HOST"],
|
||||
@@ -437,6 +444,8 @@ class SLO(Service):
|
||||
if msg.name_id:
|
||||
lid = IDP.ident.find_local_id(msg.name_id)
|
||||
logger.info("local identifier: %s" % lid)
|
||||
del IDP.uid2user[IDP.user2uid[lid]]
|
||||
del IDP.user2uid[lid]
|
||||
# remove the authentication
|
||||
try:
|
||||
IDP.remove_authn_statements(msg.name_id)
|
||||
@@ -445,7 +454,7 @@ class SLO(Service):
|
||||
resp = ServiceError("%s" % exc)
|
||||
return resp(self.environ, self.start_response)
|
||||
|
||||
resp = IDP.create_logout_response(msg)
|
||||
resp = IDP.create_logout_response(msg, [binding])
|
||||
|
||||
try:
|
||||
hinfo = IDP.apply_binding(binding, "%s" % resp, "", relay_state)
|
||||
@@ -454,11 +463,11 @@ class SLO(Service):
|
||||
resp = ServiceError("%s" % exc)
|
||||
return resp(self.environ, self.start_response)
|
||||
|
||||
logger.info("Header: %s" % (hinfo["headers"],))
|
||||
#_tlh = dict2list_of_tuples(hinfo["headers"])
|
||||
delco = delete_cookie(self.environ, "idpauthn")
|
||||
if delco:
|
||||
hinfo["headers"].append(delco)
|
||||
logger.info("Header: %s" % (hinfo["headers"],))
|
||||
resp = Response(hinfo["data"], headers=hinfo["headers"])
|
||||
return resp(self.environ, self.start_response)
|
||||
|
||||
@@ -475,9 +484,9 @@ class NMI(Service):
|
||||
request = req.message
|
||||
|
||||
# Do the necessary stuff
|
||||
name_id = IDP.ident.handle_manage_name_id_request(request.name_id,
|
||||
request.new_id, request.new_encrypted_id,
|
||||
request.terminate)
|
||||
name_id = IDP.ident.handle_manage_name_id_request(
|
||||
request.name_id, request.new_id, request.new_encrypted_id,
|
||||
request.terminate)
|
||||
|
||||
logger.debug("New NameID: %s" % name_id)
|
||||
|
||||
@@ -606,8 +615,8 @@ class NIM(Service):
|
||||
request = req.message
|
||||
# Do the necessary stuff
|
||||
try:
|
||||
name_id = IDP.ident.handle_name_id_mapping_request(request.name_id,
|
||||
request.name_id_policy)
|
||||
name_id = IDP.ident.handle_name_id_mapping_request(
|
||||
request.name_id, request.name_id_policy)
|
||||
except Unknown:
|
||||
resp = BadRequest("Unknown entity")
|
||||
return resp(self.environ, self.start_response)
|
||||
@@ -638,7 +647,10 @@ def kaka2user(kaka):
|
||||
cookie_obj = SimpleCookie(kaka)
|
||||
morsel = cookie_obj.get("idpauthn", None)
|
||||
if morsel:
|
||||
return IDP.authn[morsel.value]
|
||||
try:
|
||||
return IDP.uid2user[morsel.value]
|
||||
except KeyError:
|
||||
return None
|
||||
else:
|
||||
logger.debug("No idpauthn cookie")
|
||||
return None
|
||||
@@ -646,6 +658,7 @@ def kaka2user(kaka):
|
||||
|
||||
def delete_cookie(environ, name):
|
||||
kaka = environ.get("HTTP_COOKIE", '')
|
||||
logger.debug("delete KAKA: %s" % kaka)
|
||||
if kaka:
|
||||
cookie_obj = SimpleCookie(kaka)
|
||||
morsel = cookie_obj.get(name, None)
|
||||
@@ -739,7 +752,7 @@ def application(environ, start_response):
|
||||
try:
|
||||
query = parse_qs(environ["QUERY_STRING"])
|
||||
logger.debug("QUERY: %s" % query)
|
||||
user = IDP.authn[query["id"][0]]
|
||||
user = IDP.uid2user[query["id"][0]]
|
||||
except KeyError:
|
||||
user = None
|
||||
|
||||
|
@@ -145,8 +145,10 @@ def logout(environ, start_response, user):
|
||||
# What if more than one
|
||||
_dict = client.saml_client.global_logout(subject_id)
|
||||
logger.info("[logout] global_logout > %s" % (_dict,))
|
||||
rem = environ['repoze.who.plugins'][client.rememberer_name]
|
||||
rem.forget(environ, subject_id)
|
||||
|
||||
for key, item in _dict.item():
|
||||
for key, item in _dict.items():
|
||||
if isinstance(item, tuple):
|
||||
binding, htargs = item
|
||||
else: # result from logout, should be OK
|
||||
@@ -200,13 +202,15 @@ def application(environ, start_response):
|
||||
request is done
|
||||
:return: The response as a list of lines
|
||||
"""
|
||||
path = environ.get('PATH_INFO', '').lstrip('/')
|
||||
logger.info("<application> PATH: %s" % path)
|
||||
|
||||
user = environ.get("REMOTE_USER", "")
|
||||
if not user:
|
||||
user = environ.get("repoze.who.identity", "")
|
||||
|
||||
path = environ.get('PATH_INFO', '').lstrip('/')
|
||||
logger.info("<application> PATH: %s" % path)
|
||||
logger.info("logger name: %s" % logger.name)
|
||||
logger.info("repoze.who.identity: '%s'" % user)
|
||||
else:
|
||||
logger.info("REMOTE_USER: '%s'" % user)
|
||||
#logger.info(logging.Logger.manager.loggerDict)
|
||||
for regex, callback in urls:
|
||||
if user:
|
||||
|
@@ -5,14 +5,14 @@ BASE= "http://localhost:8087"
|
||||
#BASE= "http://lingon.catalogix.se:8087"
|
||||
|
||||
CONFIG = {
|
||||
"entityid" : "%s/sp.xml" % BASE,
|
||||
"entityid": "%s/sp.xml" % BASE,
|
||||
"description": "My SP",
|
||||
"service": {
|
||||
"sp":{
|
||||
"sp": {
|
||||
"name" : "Rolands SP",
|
||||
"endpoints":{
|
||||
"endpoints": {
|
||||
"assertion_consumer_service": [BASE],
|
||||
"single_logout_service" : [(BASE+"/slo",
|
||||
"single_logout_service": [(BASE + "/slo",
|
||||
BINDING_HTTP_REDIRECT)],
|
||||
},
|
||||
"required_attributes": ["surname", "givenname",
|
||||
@@ -20,18 +20,16 @@ CONFIG = {
|
||||
"optional_attributes": ["title"],
|
||||
}
|
||||
},
|
||||
"debug" : 1,
|
||||
"key_file" : "pki/mykey.pem",
|
||||
"cert_file" : "pki/mycert.pem",
|
||||
"attribute_map_dir" : "./attributemaps",
|
||||
"metadata" : {
|
||||
"local": ["../idp2/idp.xml"],
|
||||
},
|
||||
"debug": 1,
|
||||
"key_file": "pki/mykey.pem",
|
||||
"cert_file": "pki/mycert.pem",
|
||||
"attribute_map_dir": "./attributemaps",
|
||||
"metadata": {"local": ["../idp2/idp.xml"]},
|
||||
# -- below used by make_metadata --
|
||||
"organization": {
|
||||
"name": "Exempel AB",
|
||||
"display_name": [("Exempel AB","se"),("Example Co.","en")],
|
||||
"url":"http://www.example.com/roland",
|
||||
"display_name": [("Exempel AB", "se"), ("Example Co.", "en")],
|
||||
"url": "http://www.example.com/roland",
|
||||
},
|
||||
"contact_person": [{
|
||||
"given_name":"John",
|
||||
@@ -47,7 +45,7 @@ CONFIG = {
|
||||
"filename": "sp.log",
|
||||
"maxBytes": 100000,
|
||||
"backupCount": 5,
|
||||
},
|
||||
},
|
||||
"loglevel": "debug",
|
||||
}
|
||||
}
|
||||
|
@@ -14,9 +14,8 @@ reissue_time = 3000
|
||||
[plugin:saml2auth]
|
||||
use = s2repoze.plugins.sp:make_plugin
|
||||
saml_conf = sp_conf
|
||||
rememberer_name = auth_tkt
|
||||
remember_name = auth_tkt
|
||||
sid_store = outstanding
|
||||
identity_cache = identities
|
||||
|
||||
[general]
|
||||
request_classifier = s2repoze.plugins.challenge_decider:my_request_classifier
|
||||
|
@@ -42,6 +42,7 @@ from saml2 import ecp, BINDING_HTTP_REDIRECT
|
||||
from saml2 import BINDING_HTTP_POST
|
||||
|
||||
from saml2.client import Saml2Client
|
||||
from saml2.ident import code, decode
|
||||
from saml2.s_utils import sid
|
||||
from saml2.config import config_factory
|
||||
from saml2.profile import paos
|
||||
@@ -53,16 +54,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
PAOS_HEADER_INFO = 'ver="%s";"%s"' % (paos.NAMESPACE, ECP_SERVICE)
|
||||
|
||||
|
||||
def construct_came_from(environ):
|
||||
""" The URL that the user used when the process where interupted
|
||||
for single-sign-on processing. """
|
||||
|
||||
came_from = environ.get("PATH_INFO")
|
||||
qstr = environ.get("QUERY_STRING","")
|
||||
qstr = environ.get("QUERY_STRING", "")
|
||||
if qstr:
|
||||
came_from += '?' + qstr
|
||||
return came_from
|
||||
|
||||
|
||||
# FormPluginBase defines the methods remember and forget
|
||||
def cgi_field_storage_to_dict(field_storage):
|
||||
"""Get a plain dictionary, rather than the '.value' system used by the
|
||||
@@ -71,16 +74,15 @@ def cgi_field_storage_to_dict(field_storage):
|
||||
params = {}
|
||||
for key in field_storage.keys():
|
||||
try:
|
||||
params[ key ] = field_storage[ key ].value
|
||||
params[key] = field_storage[key].value
|
||||
except AttributeError:
|
||||
if isinstance(field_storage[ key ], basestring):
|
||||
if isinstance(field_storage[key], basestring):
|
||||
params[key] = field_storage[key]
|
||||
|
||||
return params
|
||||
|
||||
def get_body(environ):
|
||||
body = ""
|
||||
|
||||
def get_body(environ):
|
||||
length = int(environ["CONTENT_LENGTH"])
|
||||
try:
|
||||
body = environ["wsgi.input"].read(length)
|
||||
@@ -95,11 +97,13 @@ def get_body(environ):
|
||||
|
||||
return body
|
||||
|
||||
|
||||
def exception_trace(tag, exc, log):
|
||||
message = traceback.format_exception(*sys.exc_info())
|
||||
log.error("[%s] ExcList: %s" % (tag, "".join(message),))
|
||||
log.error("[%s] Exception: %s" % (tag, exc))
|
||||
|
||||
|
||||
class ECP_response(object):
|
||||
code = 200
|
||||
title = 'OK'
|
||||
@@ -114,13 +118,12 @@ class ECP_response(object):
|
||||
return [self.content]
|
||||
|
||||
|
||||
|
||||
class SAML2Plugin(FormPluginBase):
|
||||
|
||||
implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider)
|
||||
|
||||
def __init__(self, rememberer_name, config, saml_client,
|
||||
wayf, cache, debug, sid_store=None, discovery=""):
|
||||
def __init__(self, rememberer_name, config, saml_client, wayf, cache,
|
||||
debug, sid_store=None, discovery=""):
|
||||
FormPluginBase.__init__(self)
|
||||
|
||||
self.rememberer_name = rememberer_name
|
||||
@@ -148,8 +151,6 @@ class SAML2Plugin(FormPluginBase):
|
||||
:param environ: A dictionary with environment variables
|
||||
"""
|
||||
|
||||
post = {}
|
||||
|
||||
post_env = environ.copy()
|
||||
post_env['QUERY_STRING'] = ''
|
||||
|
||||
@@ -173,8 +174,8 @@ class SAML2Plugin(FormPluginBase):
|
||||
sid_ = sid()
|
||||
self.outstanding_queries[sid_] = came_from
|
||||
logger.info("Redirect to WAYF function: %s" % self.wayf)
|
||||
return -1, HTTPSeeOther(headers = [('Location',
|
||||
"%s?%s" % (self.wayf, sid_))])
|
||||
return -1, HTTPSeeOther(headers=[('Location',
|
||||
"%s?%s" % (self.wayf, sid_))])
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def _pick_idp(self, environ, came_from):
|
||||
@@ -189,6 +190,8 @@ class SAML2Plugin(FormPluginBase):
|
||||
# 'PAOS' : 'ver="%s";"%s"' % (paos.NAMESPACE, SERVICE)
|
||||
# }
|
||||
|
||||
_cli = self.saml_client
|
||||
|
||||
logger.info("[_pick_idp] %s" % environ)
|
||||
if "HTTP_PAOS" in environ:
|
||||
if environ["HTTP_PAOS"] == PAOS_HEADER_INFO:
|
||||
@@ -200,28 +203,26 @@ class SAML2Plugin(FormPluginBase):
|
||||
logger.info("- ECP client detected -")
|
||||
|
||||
_relay_state = construct_came_from(environ)
|
||||
_entityid = self.saml_client.config.ecp_endpoint(
|
||||
environ["REMOTE_ADDR"])
|
||||
_entityid = _cli.config.ecp_endpoint(environ["REMOTE_ADDR"])
|
||||
|
||||
if not _entityid:
|
||||
return -1, HTTPInternalServerError(
|
||||
detail="No IdP to talk to"
|
||||
)
|
||||
detail="No IdP to talk to")
|
||||
logger.info("IdP to talk to: %s" % _entityid)
|
||||
return ecp.ecp_auth_request(self.saml_client, _entityid,
|
||||
return ecp.ecp_auth_request(_cli, _entityid,
|
||||
_relay_state)
|
||||
else:
|
||||
return -1, HTTPInternalServerError(
|
||||
detail='Faulty Accept header')
|
||||
detail='Faulty Accept header')
|
||||
else:
|
||||
return -1, HTTPInternalServerError(
|
||||
detail='unknown ECP version')
|
||||
|
||||
detail='unknown ECP version')
|
||||
|
||||
idps = self.metadata.with_descriptor("idpsso")
|
||||
|
||||
logger.info("IdP URL: %s" % idps)
|
||||
|
||||
if len( idps ) == 1:
|
||||
if len(idps) == 1:
|
||||
# idps is a dictionary
|
||||
idp_entity_id = idps.keys()[0]
|
||||
elif not len(idps):
|
||||
@@ -229,16 +230,17 @@ class SAML2Plugin(FormPluginBase):
|
||||
else:
|
||||
idp_entity_id = ""
|
||||
logger.info("ENVIRON: %s" % environ)
|
||||
query = environ.get('s2repoze.body','')
|
||||
query = environ.get('s2repoze.body', '')
|
||||
if not query:
|
||||
query = environ.get("QUERY_STRING","")
|
||||
query = environ.get("QUERY_STRING", "")
|
||||
|
||||
logger.info("<_pick_idp> query: %s" % query)
|
||||
|
||||
if self.wayf:
|
||||
if query:
|
||||
try:
|
||||
wayf_selected = dict(parse_qs(query))["wayf_selected"][0]
|
||||
wayf_selected = dict(parse_qs(query))[
|
||||
"wayf_selected"][0]
|
||||
except KeyError:
|
||||
return self._wayf_redirect(came_from)
|
||||
idp_entity_id = wayf_selected
|
||||
@@ -246,16 +248,16 @@ class SAML2Plugin(FormPluginBase):
|
||||
return self._wayf_redirect(came_from)
|
||||
elif self.discosrv:
|
||||
if query:
|
||||
idp_entity_id = self.saml_client.parse_discovery_service_response(
|
||||
query=environ.get("QUERY_STRING"))
|
||||
idp_entity_id = _cli.parse_discovery_service_response(
|
||||
query=environ.get("QUERY_STRING"))
|
||||
else:
|
||||
sid_ = sid()
|
||||
self.outstanding_queries[sid_] = came_from
|
||||
logger.info("Redirect to Discovery Service function")
|
||||
eid = self.saml_client.config.entity_id
|
||||
loc = self.saml_client.create_discovery_service_request(
|
||||
self.discosrv, eid)
|
||||
return -1, HTTPSeeOther(headers = [('Location',loc)])
|
||||
eid = _cli.config.entity_id
|
||||
loc = _cli.create_discovery_service_request(self.discosrv,
|
||||
eid)
|
||||
return -1, HTTPSeeOther(headers=[('Location', loc)])
|
||||
else:
|
||||
return -1, HTTPNotImplemented(detail='No WAYF or DJ present!')
|
||||
|
||||
@@ -266,8 +268,10 @@ class SAML2Plugin(FormPluginBase):
|
||||
#noinspection PyUnusedLocal
|
||||
def challenge(self, environ, _status, _app_headers, _forget_headers):
|
||||
|
||||
# this challenge consist in login out
|
||||
if environ.has_key('rwpc.logout'):
|
||||
_cli = self.saml_client
|
||||
|
||||
# this challenge consist in logging out
|
||||
if 'rwpc.logout' in environ:
|
||||
# ignore right now?
|
||||
pass
|
||||
|
||||
@@ -283,7 +287,7 @@ class SAML2Plugin(FormPluginBase):
|
||||
vorg_name = environ["myapp.vo"]
|
||||
except KeyError:
|
||||
try:
|
||||
vorg_name = self.saml_client.vorg._name
|
||||
vorg_name = _cli.vorg._name
|
||||
except AttributeError:
|
||||
vorg_name = ""
|
||||
|
||||
@@ -304,7 +308,6 @@ class SAML2Plugin(FormPluginBase):
|
||||
entity_id = response
|
||||
logger.info("[sp.challenge] entity_id: %s" % entity_id)
|
||||
# Do the AuthnRequest
|
||||
_cli = self.saml_client
|
||||
_binding = BINDING_HTTP_REDIRECT
|
||||
try:
|
||||
srvs = _cli.metadata.single_sign_on_service(entity_id, _binding)
|
||||
@@ -319,7 +322,8 @@ class SAML2Plugin(FormPluginBase):
|
||||
logger.debug("ht_args: %s" % ht_args)
|
||||
except Exception, exc:
|
||||
logger.exception(exc)
|
||||
raise Exception("Failed to construct the AuthnRequest: %s" % exc)
|
||||
raise Exception(
|
||||
"Failed to construct the AuthnRequest: %s" % exc)
|
||||
|
||||
# remember the request
|
||||
self.outstanding_queries[_sid] = came_from
|
||||
@@ -327,14 +331,15 @@ class SAML2Plugin(FormPluginBase):
|
||||
if not ht_args["data"] and ht_args["headers"][0][0] == "Location":
|
||||
logger.debug('redirect to: %s' % ht_args["headers"][0][1])
|
||||
return HTTPSeeOther(headers=ht_args["headers"])
|
||||
else :
|
||||
else:
|
||||
return ht_args["data"]
|
||||
|
||||
def _construct_identity(self, session_info):
|
||||
cni = code(session_info["name_id"])
|
||||
identity = {
|
||||
"login": session_info["name_id"],
|
||||
"login": cni,
|
||||
"password": "",
|
||||
'repoze.who.userid': session_info["name_id"],
|
||||
'repoze.who.userid': cni,
|
||||
"user": session_info["ava"],
|
||||
}
|
||||
logger.debug("Identity: %s" % identity)
|
||||
@@ -349,9 +354,8 @@ class SAML2Plugin(FormPluginBase):
|
||||
# Evaluate the response, returns a AuthnResponse instance
|
||||
try:
|
||||
authresp = self.saml_client.parse_authn_request_response(
|
||||
post["SAMLResponse"],
|
||||
BINDING_HTTP_POST,
|
||||
self.outstanding_queries)
|
||||
post["SAMLResponse"], BINDING_HTTP_POST,
|
||||
self.outstanding_queries)
|
||||
except Exception, excp:
|
||||
logger.exception("Exception: %s" % (excp,))
|
||||
raise
|
||||
@@ -413,7 +417,7 @@ class SAML2Plugin(FormPluginBase):
|
||||
pass
|
||||
|
||||
try:
|
||||
if not post.has_key("SAMLResponse"):
|
||||
if "SAMLResponse" not in post:
|
||||
logger.info("[sp.identify] --- NOT SAMLResponse ---")
|
||||
# Not for me, put the post back where next in line can
|
||||
# find it
|
||||
@@ -424,8 +428,8 @@ class SAML2Plugin(FormPluginBase):
|
||||
# check for SAML2 authN response
|
||||
#if self.debug:
|
||||
try:
|
||||
session_info = self._eval_authn_response(environ,
|
||||
cgi_field_storage_to_dict(post))
|
||||
session_info = self._eval_authn_response(
|
||||
environ, cgi_field_storage_to_dict(post))
|
||||
except Exception:
|
||||
return None
|
||||
except TypeError, exc:
|
||||
@@ -445,37 +449,28 @@ class SAML2Plugin(FormPluginBase):
|
||||
|
||||
if session_info:
|
||||
environ["s2repoze.sessioninfo"] = session_info
|
||||
name_id = session_info["name_id"]
|
||||
# contruct and return the identity
|
||||
identity = {
|
||||
"login": name_id,
|
||||
"password": "",
|
||||
'repoze.who.userid': name_id,
|
||||
"user": self.saml_client.users.get_identity(name_id)[0],
|
||||
}
|
||||
logger.info("[sp.identify] IDENTITY: %s" % (identity,))
|
||||
return identity
|
||||
return self._construct_identity(session_info)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# IMetadataProvider
|
||||
def add_metadata(self, environ, identity):
|
||||
""" Add information to the knowledge I have about the user """
|
||||
subject_id = identity['repoze.who.userid']
|
||||
#logger = environ.get('repoze.who.logger','')
|
||||
name_id = identity['repoze.who.userid']
|
||||
if isinstance(name_id, basestring):
|
||||
name_id = decode(name_id)
|
||||
|
||||
_cli = self.saml_client
|
||||
logger.debug("[add_metadata] for %s" % subject_id)
|
||||
logger.debug("[add_metadata] for %s" % name_id)
|
||||
try:
|
||||
logger.debug("Issuers: %s" % _cli.users.sources(subject_id))
|
||||
logger.debug("Issuers: %s" % _cli.users.sources(name_id))
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if "user" not in identity:
|
||||
identity["user"] = {}
|
||||
try:
|
||||
(ava, _) = _cli.users.get_identity(subject_id)
|
||||
(ava, _) = _cli.users.get_identity(name_id)
|
||||
#now = time.gmtime()
|
||||
logger.debug("[add_metadata] adds: %s" % ava)
|
||||
identity["user"].update(ava)
|
||||
@@ -486,9 +481,9 @@ class SAML2Plugin(FormPluginBase):
|
||||
# is this a Virtual Organization situation
|
||||
for vo in _cli.vorg.values():
|
||||
try:
|
||||
if vo.do_aggregation(subject_id):
|
||||
if vo.do_aggregation(name_id):
|
||||
# Get the extended identity
|
||||
identity["user"] = _cli.users.get_identity(subject_id)[0]
|
||||
identity["user"] = _cli.users.get_identity(name_id)[0]
|
||||
# Only do this once, mark that the identity has been
|
||||
# expanded
|
||||
identity["pysaml2_vo_expanded"] = 1
|
||||
@@ -505,7 +500,7 @@ class SAML2Plugin(FormPluginBase):
|
||||
# used 2 times : one to get the ticket, the other to validate it
|
||||
def _service_url(self, environ, qstr=None):
|
||||
if qstr is not None:
|
||||
url = construct_url(environ, querystring = qstr)
|
||||
url = construct_url(environ, querystring=qstr)
|
||||
else:
|
||||
url = construct_url(environ)
|
||||
return url
|
||||
@@ -519,32 +514,29 @@ class SAML2Plugin(FormPluginBase):
|
||||
return None
|
||||
|
||||
|
||||
def make_plugin(rememberer_name=None, # plugin for remember
|
||||
cache= "", # cache
|
||||
# Which virtual organization to support
|
||||
virtual_organization="",
|
||||
saml_conf="",
|
||||
wayf="",
|
||||
sid_store="",
|
||||
identity_cache="",
|
||||
discovery="",
|
||||
):
|
||||
def make_plugin(remember_name=None, # plugin for remember
|
||||
cache="", # cache
|
||||
# Which virtual organization to support
|
||||
virtual_organization="",
|
||||
saml_conf="",
|
||||
wayf="",
|
||||
sid_store="",
|
||||
identity_cache="",
|
||||
discovery="",
|
||||
):
|
||||
|
||||
if saml_conf is "":
|
||||
raise ValueError(
|
||||
'must include saml_conf in configuration')
|
||||
|
||||
if rememberer_name is None:
|
||||
raise ValueError(
|
||||
'must include rememberer_name in configuration')
|
||||
if remember_name is None:
|
||||
raise ValueError('must include remember_name in configuration')
|
||||
|
||||
conf = config_factory("sp", saml_conf)
|
||||
|
||||
scl = Saml2Client(config=conf, identity_cache=identity_cache,
|
||||
virtual_organization=virtual_organization)
|
||||
virtual_organization=virtual_organization)
|
||||
|
||||
plugin = SAML2Plugin(rememberer_name, conf, scl, wayf, cache, sid_store,
|
||||
plugin = SAML2Plugin(remember_name, conf, scl, wayf, cache, sid_store,
|
||||
discovery)
|
||||
return plugin
|
||||
|
||||
|
||||
|
@@ -18,6 +18,7 @@
|
||||
"""Contains classes and functions that a SAML2.0 Service Provider (SP) may use
|
||||
to conclude its tasks.
|
||||
"""
|
||||
from saml2.ident import decode
|
||||
from saml2.httpbase import HTTPError
|
||||
from saml2.s_utils import sid
|
||||
import saml2
|
||||
@@ -99,11 +100,14 @@ class Saml2Client(Base):
|
||||
conversation.
|
||||
"""
|
||||
|
||||
if isinstance(name_id, basestring):
|
||||
name_id = decode(name_id)
|
||||
|
||||
logger.info("logout request for: %s" % name_id)
|
||||
|
||||
# find out which IdPs/AAs I should notify
|
||||
entity_ids = self.users.issuers_of_info(name_id)
|
||||
|
||||
self.users.remove_person(name_id)
|
||||
return self.do_logout(name_id, entity_ids, reason, expire, sign)
|
||||
|
||||
def do_logout(self, name_id, entity_ids, reason, expire, sign=None):
|
||||
|
@@ -204,12 +204,12 @@ class Entity(HTTPBase):
|
||||
|
||||
raise Exception("Unkown entity or unsupported bindings")
|
||||
|
||||
def message_args(self, id=0):
|
||||
if not id:
|
||||
id = sid(self.seed)
|
||||
def message_args(self, seid=0):
|
||||
if not seid:
|
||||
seid = sid(self.seed)
|
||||
|
||||
return {"id":id, "version":VERSION,
|
||||
"issue_instant":instant(), "issuer":self._issuer()}
|
||||
return {"id": seid, "version": VERSION,
|
||||
"issue_instant": instant(), "issuer": self._issuer()}
|
||||
|
||||
def response_args(self, message, bindings=None, descr_type=""):
|
||||
info = {"in_response_to": message.id}
|
||||
@@ -278,10 +278,9 @@ class Entity(HTTPBase):
|
||||
:param text: The SOAP message
|
||||
:return: A dictionary with two keys "body" and "header"
|
||||
"""
|
||||
return class_instances_from_soap_enveloped_saml_thingies(text,
|
||||
[paos,
|
||||
ecp,
|
||||
samlp])
|
||||
return class_instances_from_soap_enveloped_saml_thingies(text, [paos,
|
||||
ecp,
|
||||
samlp])
|
||||
|
||||
def unpack_soap_message(self, text):
|
||||
"""
|
||||
@@ -367,8 +366,8 @@ class Entity(HTTPBase):
|
||||
_issuer = self._issuer(issuer)
|
||||
|
||||
response = response_factory(issuer=_issuer,
|
||||
in_response_to = in_response_to,
|
||||
status = status)
|
||||
in_response_to=in_response_to,
|
||||
status=status)
|
||||
|
||||
if consumer_url:
|
||||
response.destination = consumer_url
|
||||
@@ -377,7 +376,7 @@ class Entity(HTTPBase):
|
||||
setattr(response, key, val)
|
||||
|
||||
if sign:
|
||||
self.sign(response,to_sign=to_sign)
|
||||
self.sign(response, to_sign=to_sign)
|
||||
elif to_sign:
|
||||
return signed_instance_factory(response, self.sec, to_sign)
|
||||
else:
|
||||
@@ -689,8 +688,8 @@ class Entity(HTTPBase):
|
||||
if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]:
|
||||
try:
|
||||
# expected return address
|
||||
kwargs["return_addr"] = self.config.endpoint(service,
|
||||
binding=binding)[0]
|
||||
kwargs["return_addr"] = self.config.endpoint(
|
||||
service, binding=binding)[0]
|
||||
except Exception:
|
||||
logger.info("Not supposed to handle this!")
|
||||
return None
|
||||
|
@@ -19,8 +19,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
ATTRS = {"version":None,
|
||||
"name":"",
|
||||
ATTRS = {"version": None,
|
||||
"name": "",
|
||||
"value": None,
|
||||
"port": None,
|
||||
"port_specified": False,
|
||||
@@ -43,19 +43,22 @@ PAIRS = {
|
||||
"path": "path_specified"
|
||||
}
|
||||
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class HTTPError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _since_epoch(cdate):
|
||||
"""
|
||||
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
|
||||
:return: UTC time
|
||||
"""
|
||||
|
||||
if len(cdate) < 29: # somethings broken
|
||||
if len(cdate) < 29: # somethings broken
|
||||
if len(cdate) < 5:
|
||||
return utc_now()
|
||||
|
||||
@@ -67,16 +70,19 @@ def _since_epoch(cdate):
|
||||
#return int(time.mktime(t))
|
||||
return calendar.timegm(t)
|
||||
|
||||
|
||||
def set_list2dict(sl):
|
||||
return dict(sl)
|
||||
|
||||
|
||||
def dict2set_list(dic):
|
||||
return [(k,v) for k,v in dic.items()]
|
||||
return [(k, v) for k, v in dic.items()]
|
||||
|
||||
|
||||
class HTTPBase(object):
|
||||
def __init__(self, verify=True, ca_bundle=None, key_file=None,
|
||||
cert_file=None):
|
||||
self.request_args = {"allow_redirects": False,}
|
||||
self.request_args = {"allow_redirects": False}
|
||||
#self.cookies = {}
|
||||
self.cookiejar = cookielib.CookieJar()
|
||||
|
||||
@@ -98,6 +104,12 @@ class HTTPBase(object):
|
||||
:return:
|
||||
"""
|
||||
part = urlparse.urlparse(url)
|
||||
|
||||
#if part.port:
|
||||
# _domain = "%s:%s" % (part.hostname, part.port)
|
||||
#else:
|
||||
_domain = part.hostname
|
||||
|
||||
cookie_dict = {}
|
||||
now = utc_now()
|
||||
for _, a in list(self.cookiejar._cookies.items()):
|
||||
@@ -106,6 +118,8 @@ class HTTPBase(object):
|
||||
# print cookie
|
||||
if cookie.expires and cookie.expires <= now:
|
||||
continue
|
||||
if not re.search("%s$" % cookie.domain, _domain):
|
||||
continue
|
||||
if not re.match(cookie.path, part.path):
|
||||
continue
|
||||
|
||||
@@ -116,8 +130,13 @@ class HTTPBase(object):
|
||||
def set_cookie(self, kaka, request):
|
||||
"""Returns a cookielib.Cookie based on a set-cookie header line"""
|
||||
|
||||
# default rfc2109=False
|
||||
# max-age, httponly
|
||||
if not kaka:
|
||||
return
|
||||
|
||||
part = urlparse.urlparse(request.url)
|
||||
_domain = part.hostname
|
||||
logger.debug("%s: '%s'" % (_domain, kaka))
|
||||
|
||||
for cookie_name, morsel in kaka.items():
|
||||
std_attr = ATTRS.copy()
|
||||
std_attr["name"] = cookie_name
|
||||
@@ -133,9 +152,9 @@ class HTTPBase(object):
|
||||
if attr in ATTRS:
|
||||
if morsel[attr]:
|
||||
if attr == "expires":
|
||||
std_attr[attr]=_since_epoch(morsel[attr])
|
||||
std_attr[attr] = _since_epoch(morsel[attr])
|
||||
else:
|
||||
std_attr[attr]=morsel[attr]
|
||||
std_attr[attr] = morsel[attr]
|
||||
elif attr == "max-age":
|
||||
if morsel["max-age"]:
|
||||
std_attr["expires"] = _since_epoch(morsel["max-age"])
|
||||
@@ -144,8 +163,12 @@ class HTTPBase(object):
|
||||
if std_attr[att]:
|
||||
std_attr[item] = True
|
||||
|
||||
if std_attr["domain"] and std_attr["domain"].startswith("."):
|
||||
std_attr["domain_initial_dot"] = True
|
||||
if std_attr["domain"]:
|
||||
if std_attr["domain"].startswith("."):
|
||||
std_attr["domain_initial_dot"] = True
|
||||
else:
|
||||
std_attr["domain"] = _domain
|
||||
std_attr["domain_specified"] = True
|
||||
|
||||
if morsel["max-age"] is 0:
|
||||
try:
|
||||
@@ -154,6 +177,13 @@ class HTTPBase(object):
|
||||
name=std_attr["name"])
|
||||
except ValueError:
|
||||
pass
|
||||
elif morsel["expires"] < utc_now():
|
||||
try:
|
||||
self.cookiejar.clear(domain=std_attr["domain"],
|
||||
path=std_attr["path"],
|
||||
name=std_attr["name"])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
new_cookie = cookielib.Cookie(**std_attr)
|
||||
self.cookiejar.set_cookie(new_cookie)
|
||||
@@ -164,21 +194,22 @@ class HTTPBase(object):
|
||||
_kwargs.update(kwargs)
|
||||
|
||||
if self.cookiejar:
|
||||
_kwargs["cookies"] = self.cookies(url)
|
||||
_cd = self.cookies(url)
|
||||
if _cd:
|
||||
_kwargs["cookies"] = _cd
|
||||
logger.debug("Sent cookies: %s" % _kwargs["cookies"])
|
||||
|
||||
if self.user and self.passwd:
|
||||
_kwargs["auth"]= (self.user, self.passwd)
|
||||
_kwargs["auth"] = (self.user, self.passwd)
|
||||
|
||||
#logger.info("SENT COOKIEs: %s" % (_kwargs["cookies"],))
|
||||
try:
|
||||
r = requests.request(method, url, **_kwargs)
|
||||
except requests.ConnectionError, exc:
|
||||
raise ConnectionError("%s" % exc)
|
||||
|
||||
try:
|
||||
#logger.info("RECEIVED COOKIEs: %s" % (r.headers["set-cookie"],))
|
||||
self.set_cookie(SimpleCookie(r.headers["set-cookie"]), r)
|
||||
except AttributeError, err:
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return r
|
||||
@@ -196,7 +227,7 @@ class HTTPBase(object):
|
||||
:return: dictionary
|
||||
"""
|
||||
if not isinstance(message, basestring):
|
||||
request = "%s" % (message,)
|
||||
message = "%s" % (message,)
|
||||
|
||||
return http_form_post_message(message, destination, relay_state, typ)
|
||||
|
||||
@@ -213,7 +244,7 @@ class HTTPBase(object):
|
||||
:return: dictionary
|
||||
"""
|
||||
if not isinstance(message, basestring):
|
||||
request = "%s" % (message,)
|
||||
message = "%s" % (message,)
|
||||
|
||||
return http_redirect_message(message, destination, relay_state, typ)
|
||||
|
||||
@@ -278,7 +309,7 @@ class HTTPBase(object):
|
||||
soap_message = _signed
|
||||
|
||||
return {"url": destination, "method": "POST",
|
||||
"data":soap_message, "headers":headers}
|
||||
"data": soap_message, "headers": headers}
|
||||
|
||||
def send_using_soap(self, request, destination, headers=None, sign=False):
|
||||
"""
|
||||
@@ -304,7 +335,7 @@ class HTTPBase(object):
|
||||
logger.info("SOAP response: %s" % response.text)
|
||||
return response
|
||||
else:
|
||||
raise HTTPError("%d:%s" % (response.status_code, response.error))
|
||||
raise HTTPError("%d:%s" % (response.status_code, response.content))
|
||||
|
||||
def add_credentials(self, user, passwd):
|
||||
self.user = user
|
||||
|
@@ -202,6 +202,7 @@ class MetaData(object):
|
||||
res[srv["binding"]].append(srv)
|
||||
except KeyError:
|
||||
res[srv["binding"]] = [srv]
|
||||
logger.debug("_service => %s" % res)
|
||||
return res
|
||||
|
||||
def _ext_service(self, entity_id, typ, service, binding):
|
||||
|
@@ -28,7 +28,13 @@ import saml2
|
||||
import base64
|
||||
import urllib
|
||||
from saml2.s_utils import deflate_and_base64_encode
|
||||
from saml2.s_utils import Unsupported
|
||||
import logging
|
||||
from saml2.sigver import RSA_SHA1
|
||||
from saml2.sigver import REQ_ORDER
|
||||
from saml2.sigver import RESP_ORDER
|
||||
from saml2.sigver import RSASigner
|
||||
from saml2.sigver import sha1_digest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +57,9 @@ FORM_SPEC = """<form method="post" action="%s">
|
||||
<input type="hidden" name="RelayState" value="%s" />
|
||||
</form>"""
|
||||
|
||||
def http_form_post_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
|
||||
def http_form_post_message(message, location, relay_state="",
|
||||
typ="SAMLRequest"):
|
||||
"""The HTTP POST binding defines a mechanism by which SAML protocol
|
||||
messages may be transmitted within the base64-encoded content of a
|
||||
HTML form control.
|
||||
@@ -81,19 +89,9 @@ def http_form_post_message(message, location, relay_state="", typ="SAMLRequest")
|
||||
|
||||
return {"headers": [("Content-type", "text/html")], "data": response}
|
||||
|
||||
##noinspection PyUnresolvedReferences
|
||||
#def http_post_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
# """
|
||||
#
|
||||
# :param message:
|
||||
# :param location:
|
||||
# :param relay_state:
|
||||
# :param typ:
|
||||
# :return:
|
||||
# """
|
||||
# return {"headers": [("Content-type", "text/xml")], "data": message}
|
||||
|
||||
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest",
|
||||
sigalg=None, key=None):
|
||||
"""The HTTP Redirect binding defines a mechanism by which SAML protocol
|
||||
messages can be transmitted within URL parameters.
|
||||
Messages are encoded for use with this binding using a URL encoding
|
||||
@@ -104,13 +102,21 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
:param message: The message
|
||||
:param location: Where the message should be posted to
|
||||
:param relay_state: for preserving and conveying state information
|
||||
:param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart
|
||||
:param sigalg: The signature algorithm to use.
|
||||
:param key: Key to use for signing
|
||||
:return: A tuple containing header information and a HTML message.
|
||||
"""
|
||||
|
||||
if not isinstance(message, basestring):
|
||||
message = "%s" % (message,)
|
||||
|
||||
_order = None
|
||||
if typ in ["SAMLRequest", "SAMLResponse"]:
|
||||
if typ == "SAMLRequest":
|
||||
_order = REQ_ORDER
|
||||
else:
|
||||
_order = RESP_ORDER
|
||||
args = {typ: deflate_and_base64_encode(message)}
|
||||
elif typ == "SAMLart":
|
||||
args = {typ: message}
|
||||
@@ -120,16 +126,35 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
||||
if relay_state:
|
||||
args["RelayState"] = relay_state
|
||||
|
||||
if sigalg:
|
||||
# sigalgs
|
||||
# http://www.w3.org/2000/09/xmldsig#dsa-sha1
|
||||
# http://www.w3.org/2000/09/xmldsig#rsa-sha1
|
||||
|
||||
args["SigAlg"] = sigalg
|
||||
|
||||
if sigalg == RSA_SHA1:
|
||||
signer = RSASigner(sha1_digest, "sha1")
|
||||
string = "&".join([urllib.urlencode({k: args[k]}) for k in _order])
|
||||
args["Signature"] = base64.b64encode(signer.sign(string, key))
|
||||
string = urllib.urlencode(args)
|
||||
else:
|
||||
raise Unsupported("Signing algorithm")
|
||||
else:
|
||||
string = urllib.urlencode(args)
|
||||
|
||||
glue_char = "&" if urlparse.urlparse(location).query else "?"
|
||||
login_url = glue_char.join([location, urllib.urlencode(args)])
|
||||
login_url = glue_char.join([location, string])
|
||||
headers = [('Location', login_url)]
|
||||
body = []
|
||||
|
||||
return {"headers":headers, "data":body}
|
||||
return {"headers": headers, "data": body}
|
||||
|
||||
|
||||
DUMMY_NAMESPACE = "http://example.org/"
|
||||
PREFIX = '<?xml version="1.0" encoding="UTF-8"?>'
|
||||
|
||||
|
||||
def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
|
||||
""" Returns a soap envelope containing a SAML request
|
||||
as a text string.
|
||||
@@ -170,21 +195,24 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
|
||||
cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1]
|
||||
_str = _str.replace(cut1, "")
|
||||
first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],))
|
||||
last = _str.find(">", first+14)
|
||||
cut2 = _str[first:last+1]
|
||||
last = _str.find(">", first + 14)
|
||||
cut2 = _str[first:last + 1]
|
||||
return _str.replace(cut2, thingy)
|
||||
else:
|
||||
thingy.become_child_element_of(body)
|
||||
return ElementTree.tostring(envelope, encoding="UTF-8")
|
||||
|
||||
|
||||
def http_soap_message(message):
|
||||
return {"headers": [("Content-type", "application/soap+xml")],
|
||||
"data": make_soap_enveloped_saml_thingy(message)}
|
||||
|
||||
|
||||
def http_paos(message, extra=None):
|
||||
return {"headers":[("Content-type", "application/soap+xml")],
|
||||
return {"headers": [("Content-type", "application/soap+xml")],
|
||||
"data": make_soap_enveloped_saml_thingy(message, extra)}
|
||||
|
||||
|
||||
def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
||||
"""Parses a SOAP enveloped SAML thing and returns header parts and body
|
||||
|
||||
@@ -205,7 +233,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
||||
body = saml2.create_class_from_element_tree(body_class, sub)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Wrong body type (%s) in SOAP envelope" % sub.tag)
|
||||
"Wrong body type (%s) in SOAP envelope" % sub.tag)
|
||||
elif part.tag == '{%s}Header' % NAMESPACE:
|
||||
if not header_class:
|
||||
raise Exception("Header where I didn't expect one")
|
||||
@@ -226,13 +254,15 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
||||
PACKING = {
|
||||
saml2.BINDING_HTTP_REDIRECT: http_redirect_message,
|
||||
saml2.BINDING_HTTP_POST: http_form_post_message,
|
||||
}
|
||||
}
|
||||
|
||||
def packager( identifier ):
|
||||
|
||||
def packager(identifier):
|
||||
try:
|
||||
return PACKING[identifier]
|
||||
except KeyError:
|
||||
raise Exception("Unkown binding type: %s" % identifier)
|
||||
|
||||
|
||||
def factory(binding, message, location, relay_state="", typ="SAMLRequest"):
|
||||
return PACKING[binding](message, location, relay_state, typ)
|
@@ -9,7 +9,7 @@ import sys
|
||||
import hmac
|
||||
|
||||
# from python 2.5
|
||||
if sys.version_info >= (2,5):
|
||||
if sys.version_info >= (2, 5):
|
||||
import hashlib
|
||||
else: # before python 2.5
|
||||
import sha
|
||||
@@ -27,36 +27,51 @@ import zlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SamlException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class RequestVersionTooLow(SamlException):
|
||||
pass
|
||||
|
||||
|
||||
class RequestVersionTooHigh(SamlException):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownPrincipal(SamlException):
|
||||
pass
|
||||
|
||||
class UnsupportedBinding(SamlException):
|
||||
|
||||
class Unsupported(SamlException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedBinding(Unsupported):
|
||||
pass
|
||||
|
||||
|
||||
class VersionMismatch(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Unknown(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OtherError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingValue(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PolicyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BadRequest(Exception):
|
||||
pass
|
||||
|
||||
@@ -73,28 +88,29 @@ EXCEPTION2STATUS = {
|
||||
Exception: samlp.STATUS_AUTHN_FAILED,
|
||||
}
|
||||
|
||||
GENERIC_DOMAINS = "aero", "asia", "biz", "cat", "com", "coop", \
|
||||
"edu", "gov", "info", "int", "jobs", "mil", "mobi", "museum", \
|
||||
"name", "net", "org", "pro", "tel", "travel"
|
||||
GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu",
|
||||
"gov", "info", "int", "jobs", "mil", "mobi", "museum",
|
||||
"name", "net", "org", "pro", "tel", "travel"]
|
||||
|
||||
def valid_email(emailaddress, domains = GENERIC_DOMAINS):
|
||||
|
||||
def valid_email(emailaddress, domains=GENERIC_DOMAINS):
|
||||
"""Checks for a syntactically valid email address."""
|
||||
|
||||
# Email address must be at least 6 characters in total.
|
||||
# Assuming noone may have addresses of the type a@com
|
||||
if len(emailaddress) < 6:
|
||||
return False # Address too short.
|
||||
return False # Address too short.
|
||||
|
||||
# Split up email address into parts.
|
||||
try:
|
||||
localpart, domainname = emailaddress.rsplit('@', 1)
|
||||
host, toplevel = domainname.rsplit('.', 1)
|
||||
except ValueError:
|
||||
return False # Address does not have enough parts.
|
||||
return False # Address does not have enough parts.
|
||||
|
||||
# Check for Country code or Generic Domain.
|
||||
if len(toplevel) != 2 and toplevel not in domains:
|
||||
return False # Not a domain name.
|
||||
return False # Not a domain name.
|
||||
|
||||
for i in '-_.%+.':
|
||||
localpart = localpart.replace(i, "")
|
||||
@@ -102,27 +118,30 @@ def valid_email(emailaddress, domains = GENERIC_DOMAINS):
|
||||
host = host.replace(i, "")
|
||||
|
||||
if localpart.isalnum() and host.isalnum():
|
||||
return True # Email address is fine.
|
||||
return True # Email address is fine.
|
||||
else:
|
||||
return False # Email address has funny characters.
|
||||
return False # Email address has funny characters.
|
||||
|
||||
def decode_base64_and_inflate( string ):
|
||||
|
||||
def decode_base64_and_inflate(string):
|
||||
""" base64 decodes and then inflates according to RFC1951
|
||||
|
||||
:param string: a deflated and encoded string
|
||||
:return: the string after decoding and inflating
|
||||
"""
|
||||
|
||||
return zlib.decompress( base64.b64decode( string ) , -15)
|
||||
return zlib.decompress(base64.b64decode(string), -15)
|
||||
|
||||
def deflate_and_base64_encode( string_val ):
|
||||
|
||||
def deflate_and_base64_encode(string_val):
|
||||
"""
|
||||
Deflates and the base64 encodes a string
|
||||
|
||||
:param string_val: The string to deflate and encode
|
||||
:return: The deflated and encoded string
|
||||
"""
|
||||
return base64.b64encode( zlib.compress( string_val )[2:-4] )
|
||||
return base64.b64encode(zlib.compress(string_val)[2:-4])
|
||||
|
||||
|
||||
def rndstr(size=16):
|
||||
"""
|
||||
@@ -134,9 +153,11 @@ def rndstr(size=16):
|
||||
_basech = string.ascii_letters + string.digits
|
||||
return "".join([random.choice(_basech) for _ in range(size)])
|
||||
|
||||
|
||||
def sid(seed=""):
|
||||
"""The hash of the server time + seed makes an unique SID for each session.
|
||||
128-bits long so it fulfills the SAML2 requirements which states 128-160 bits
|
||||
128-bits long so it fulfills the SAML2 requirements which states
|
||||
128-160 bits
|
||||
|
||||
:param seed: A seed string
|
||||
:return: The hex version of the digest, prefixed by 'id-' to make it
|
||||
@@ -146,7 +167,8 @@ def sid(seed=""):
|
||||
ident.update(repr(time.time()))
|
||||
if seed:
|
||||
ident.update(seed)
|
||||
return "id-"+ident.hexdigest()
|
||||
return "id-" + ident.hexdigest()
|
||||
|
||||
|
||||
def parse_attribute_map(filenames):
|
||||
"""
|
||||
@@ -168,6 +190,7 @@ def parse_attribute_map(filenames):
|
||||
|
||||
return forward, backward
|
||||
|
||||
|
||||
def identity_attribute(form, attribute, forward_map=None):
|
||||
if form == "friendly":
|
||||
if attribute.friendly_name:
|
||||
@@ -182,6 +205,7 @@ def identity_attribute(form, attribute, forward_map=None):
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def error_status_factory(info):
|
||||
if isinstance(info, Exception):
|
||||
try:
|
||||
@@ -194,39 +218,38 @@ def error_status_factory(info):
|
||||
status_code=samlp.StatusCode(
|
||||
value=samlp.STATUS_RESPONDER,
|
||||
status_code=samlp.StatusCode(
|
||||
value=exc_val)
|
||||
),
|
||||
)
|
||||
value=exc_val)))
|
||||
else:
|
||||
(errcode, text) = info
|
||||
status = samlp.Status(
|
||||
status_message=samlp.StatusMessage(text=text),
|
||||
status_code=samlp.StatusCode(
|
||||
value=samlp.STATUS_RESPONDER,
|
||||
status_code=samlp.StatusCode(value=errcode)
|
||||
),
|
||||
)
|
||||
status_code=samlp.StatusCode(value=errcode)))
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def success_status_factory():
|
||||
return samlp.Status(status_code=samlp.StatusCode(
|
||||
value=samlp.STATUS_SUCCESS))
|
||||
value=samlp.STATUS_SUCCESS))
|
||||
|
||||
|
||||
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
|
||||
return samlp.Status(
|
||||
status_message=samlp.StatusMessage(text=message),
|
||||
status_code=samlp.StatusCode(
|
||||
value=fro,
|
||||
status_code=samlp.StatusCode(value=code)))
|
||||
status_code=samlp.StatusCode(value=fro,
|
||||
status_code=samlp.StatusCode(value=code)))
|
||||
|
||||
|
||||
def assertion_factory(**kwargs):
|
||||
assertion = saml.Assertion(version=VERSION, id=sid(),
|
||||
issue_instant=instant())
|
||||
issue_instant=instant())
|
||||
for key, val in kwargs.items():
|
||||
setattr(assertion, key, val)
|
||||
return assertion
|
||||
|
||||
|
||||
def _attrval(val, typ=""):
|
||||
if isinstance(val, list) or isinstance(val, set):
|
||||
attrval = [saml.AttributeValue(text=v) for v in val]
|
||||
@@ -246,6 +269,7 @@ def _attrval(val, typ=""):
|
||||
# xmlns:xs="http://www.w3.org/2001/XMLSchema"
|
||||
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
|
||||
|
||||
def do_ava(val, typ=""):
|
||||
if isinstance(val, basestring):
|
||||
ava = saml.AttributeValue()
|
||||
@@ -253,7 +277,7 @@ def do_ava(val, typ=""):
|
||||
attrval = [ava]
|
||||
elif isinstance(val, list):
|
||||
attrval = [do_ava(v)[0] for v in val]
|
||||
elif val or val == False:
|
||||
elif val or val is False:
|
||||
ava = saml.AttributeValue()
|
||||
ava.set_text(val)
|
||||
attrval = [ava]
|
||||
@@ -268,6 +292,7 @@ def do_ava(val, typ=""):
|
||||
|
||||
return attrval
|
||||
|
||||
|
||||
def do_attribute(val, typ, key):
|
||||
attr = saml.Attribute()
|
||||
attrval = do_ava(val, typ)
|
||||
@@ -276,7 +301,7 @@ def do_attribute(val, typ, key):
|
||||
|
||||
if isinstance(key, basestring):
|
||||
attr.name = key
|
||||
elif isinstance(key, tuple): # 3-tuple or 2-tuple
|
||||
elif isinstance(key, tuple): # 3-tuple or 2-tuple
|
||||
try:
|
||||
(name, nformat, friendly) = key
|
||||
except ValueError:
|
||||
@@ -290,6 +315,7 @@ def do_attribute(val, typ, key):
|
||||
attr.friendly_name = friendly
|
||||
return attr
|
||||
|
||||
|
||||
def do_attributes(identity):
|
||||
attrs = []
|
||||
if not identity:
|
||||
@@ -308,6 +334,7 @@ def do_attributes(identity):
|
||||
attrs.append(attr)
|
||||
return attrs
|
||||
|
||||
|
||||
def do_attribute_statement(identity):
|
||||
"""
|
||||
:param identity: A dictionary with fiendly names as keys
|
||||
@@ -315,12 +342,14 @@ def do_attribute_statement(identity):
|
||||
"""
|
||||
return saml.AttributeStatement(attribute=do_attributes(identity))
|
||||
|
||||
|
||||
def factory(klass, **kwargs):
|
||||
instance = klass()
|
||||
for key, val in kwargs.items():
|
||||
setattr(instance, key, val)
|
||||
return instance
|
||||
|
||||
|
||||
def signature(secret, parts):
|
||||
"""Generates a signature.
|
||||
"""
|
||||
@@ -334,6 +363,7 @@ def signature(secret, parts):
|
||||
|
||||
return csum.hexdigest()
|
||||
|
||||
|
||||
def verify_signature(secret, parts):
|
||||
""" Checks that the signature is correct """
|
||||
if signature(secret, parts[:-1]) == parts[-1]:
|
||||
@@ -344,9 +374,10 @@ def verify_signature(secret, parts):
|
||||
|
||||
FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#"
|
||||
|
||||
|
||||
def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
|
||||
"""
|
||||
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value ) '#'
|
||||
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value) '#'
|
||||
Allowed attributes:
|
||||
TS the login time stamp
|
||||
RP the relying party entityID
|
||||
|
@@ -49,7 +49,7 @@ from saml2.assertion import Policy
|
||||
from saml2.assertion import restriction_from_attribute_spec
|
||||
from saml2.assertion import filter_attribute_value_assertions
|
||||
|
||||
from saml2.ident import IdentDB
|
||||
from saml2.ident import IdentDB, code
|
||||
#from saml2.profile import paos
|
||||
from saml2.profile import ecp
|
||||
|
||||
@@ -70,6 +70,8 @@ class Server(Entity):
|
||||
self.ticket = {}
|
||||
self.authn = {}
|
||||
self.assertion = {}
|
||||
self.user2uid = {}
|
||||
self.uid2user = {}
|
||||
|
||||
def init_config(self, stype="idp"):
|
||||
""" Remaining init of the server configuration
|
||||
@@ -194,8 +196,8 @@ class Server(Entity):
|
||||
def store_assertion(self, assertion, to_sign):
|
||||
self.assertion[assertion.id] = (assertion, to_sign)
|
||||
|
||||
def get_assertion(self, id):
|
||||
return self.assertion[id]
|
||||
def get_assertion(self, cid):
|
||||
return self.assertion[cid]
|
||||
|
||||
def store_authn_statement(self, authn_statement, name_id):
|
||||
"""
|
||||
@@ -204,7 +206,8 @@ class Server(Entity):
|
||||
:param name_id:
|
||||
:return:
|
||||
"""
|
||||
nkey = sha1("%s" % name_id).hexdigest()
|
||||
logger.debug("store authn about: %s" % name_id)
|
||||
nkey = sha1(code(name_id)).hexdigest()
|
||||
logger.debug("Store authn_statement under key: %s" % nkey)
|
||||
try:
|
||||
self.authn[nkey].append(authn_statement)
|
||||
@@ -221,8 +224,14 @@ class Server(Entity):
|
||||
:return:
|
||||
"""
|
||||
result = []
|
||||
key = sha1("%s" % name_id).hexdigest()
|
||||
for statement in self.authn[key]:
|
||||
key = sha1(code(name_id)).hexdigest()
|
||||
try:
|
||||
statements = self.authn[key]
|
||||
except KeyError:
|
||||
logger.info("Unknown subject %s" % name_id)
|
||||
return []
|
||||
|
||||
for statement in statements:
|
||||
if session_index:
|
||||
if statement.session_index != session_index:
|
||||
continue
|
||||
@@ -234,7 +243,8 @@ class Server(Entity):
|
||||
return result
|
||||
|
||||
def remove_authn_statements(self, name_id):
|
||||
nkey = sha1("%s" % name_id).hexdigest()
|
||||
logger.debug("remove authn about: %s" % name_id)
|
||||
nkey = sha1(code(name_id)).hexdigest()
|
||||
|
||||
del self.authn[nkey]
|
||||
|
||||
|
@@ -21,11 +21,15 @@ Based on the use of xmlsec1 binaries and not the python xmlsec module.
|
||||
|
||||
import base64
|
||||
from binascii import hexlify
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import os
|
||||
import sys
|
||||
from time import mktime
|
||||
import urllib
|
||||
import M2Crypto
|
||||
from M2Crypto.X509 import load_cert_string
|
||||
from saml2.samlp import Response
|
||||
|
||||
import xmldsig as ds
|
||||
@@ -37,8 +41,11 @@ from saml2 import ExtensionElement
|
||||
from saml2 import VERSION
|
||||
|
||||
from saml2.s_utils import sid
|
||||
from saml2.s_utils import Unsupported
|
||||
|
||||
from saml2.time_util import instant
|
||||
from saml2.time_util import utc_now
|
||||
from saml2.time_util import str_to_time
|
||||
|
||||
from tempfile import NamedTemporaryFile
|
||||
from subprocess import Popen, PIPE
|
||||
@@ -47,6 +54,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature")
|
||||
|
||||
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
|
||||
|
||||
|
||||
def signed(item):
|
||||
if SIG in item.c_children.keys() and item.signature:
|
||||
return True
|
||||
@@ -62,6 +72,7 @@ def signed(item):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_xmlsec_binary(paths=None):
|
||||
"""
|
||||
Tries to find the xmlsec1 binary.
|
||||
@@ -75,7 +86,7 @@ def get_xmlsec_binary(paths=None):
|
||||
bin_name = "xmlsec1"
|
||||
elif os.name == "nt":
|
||||
bin_name = "xmlsec1.exe"
|
||||
else: # Default !?
|
||||
else: # Default !?
|
||||
bin_name = "xmlsec1"
|
||||
|
||||
if paths:
|
||||
@@ -109,47 +120,36 @@ ENC_KEY_CLASS = "EncryptedKey"
|
||||
|
||||
_TEST_ = True
|
||||
|
||||
|
||||
class SignatureError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class XmlsecError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingKey(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DecryptError(Exception):
|
||||
pass
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
#def make_signed_instance(klass, spec, seccont, base64encode=False):
|
||||
# """ Will only return signed instance if the signature
|
||||
# preamble is present
|
||||
#
|
||||
# :param klass: The type of class the instance should be
|
||||
# :param spec: The specification of attributes and children of the class
|
||||
# :param seccont: The security context (instance of SecurityContext)
|
||||
# :param base64encode: Whether the attribute values should be base64 encoded
|
||||
# :return: A signed (or possibly unsigned) instance of the class
|
||||
# """
|
||||
# if "signature" in spec:
|
||||
# signed_xml = seccont.sign_statement_using_xmlsec("%s" % instance,
|
||||
# class_name(instance), instance.id)
|
||||
# return create_class_from_xml_string(instance.__class__, signed_xml)
|
||||
# else:
|
||||
# return make_instance(klass, spec, base64encode)
|
||||
|
||||
def xmlsec_version(execname):
|
||||
com_list = [execname,"--version"]
|
||||
com_list = [execname, "--version"]
|
||||
pof = Popen(com_list, stderr=PIPE, stdout=PIPE)
|
||||
try:
|
||||
return pof.stdout.read().split(" ")[1]
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
||||
base64encode=False, elements_to_sign=None):
|
||||
base64encode=False, elements_to_sign=None):
|
||||
"""
|
||||
Creates a class instance with a specified value, the specified
|
||||
class instance may be a value on a property in a defined class instance.
|
||||
@@ -169,14 +169,14 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
||||
|
||||
if isinstance(val, dict):
|
||||
cinst = _instance(klass, val, seccont, base64encode=base64encode,
|
||||
elements_to_sign=elements_to_sign)
|
||||
elements_to_sign=elements_to_sign)
|
||||
else:
|
||||
try:
|
||||
cinst = klass().set_text(val)
|
||||
except ValueError:
|
||||
if not part:
|
||||
cis = [_make_vals(sval, klass, seccont, klass_inst, prop,
|
||||
True, base64encode, elements_to_sign) for sval in val]
|
||||
True, base64encode, elements_to_sign) for sval in val]
|
||||
setattr(klass_inst, prop, cis)
|
||||
else:
|
||||
raise
|
||||
@@ -188,6 +188,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
||||
cis = [cinst]
|
||||
setattr(klass_inst, prop, cis)
|
||||
|
||||
|
||||
def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
||||
instance = klass()
|
||||
|
||||
@@ -208,30 +209,31 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
||||
#print "## %s, %s" % (prop, klassdef)
|
||||
if prop in ava:
|
||||
#print "### %s" % ava[prop]
|
||||
if isinstance(klassdef, list): # means there can be a list of values
|
||||
if isinstance(klassdef, list):
|
||||
# means there can be a list of values
|
||||
_make_vals(ava[prop], klassdef[0], seccont, instance, prop,
|
||||
base64encode=base64encode,
|
||||
elements_to_sign=elements_to_sign)
|
||||
base64encode=base64encode,
|
||||
elements_to_sign=elements_to_sign)
|
||||
else:
|
||||
cis = _make_vals(ava[prop], klassdef, seccont, instance, prop,
|
||||
True, base64encode, elements_to_sign)
|
||||
True, base64encode, elements_to_sign)
|
||||
setattr(instance, prop, cis)
|
||||
|
||||
if "extension_elements" in ava:
|
||||
for item in ava["extension_elements"]:
|
||||
instance.extension_elements.append(
|
||||
ExtensionElement(item["tag"]).loadd(item))
|
||||
ExtensionElement(item["tag"]).loadd(item))
|
||||
|
||||
if "extension_attributes" in ava:
|
||||
for key, val in ava["extension_attributes"].items():
|
||||
instance.extension_attributes[key] = val
|
||||
|
||||
|
||||
|
||||
if "signature" in ava:
|
||||
elements_to_sign.append((class_name(instance), instance.id))
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
||||
"""
|
||||
|
||||
@@ -243,8 +245,8 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
||||
if elements_to_sign:
|
||||
signed_xml = "%s" % instance
|
||||
for (node_name, nodeid) in elements_to_sign:
|
||||
signed_xml = seccont.sign_statement_using_xmlsec(signed_xml,
|
||||
klass_namn=node_name, nodeid=nodeid)
|
||||
signed_xml = seccont.sign_statement_using_xmlsec(
|
||||
signed_xml, klass_namn=node_name, nodeid=nodeid)
|
||||
|
||||
#print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
#print "%s" % signed_xml
|
||||
@@ -255,6 +257,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_id():
|
||||
""" Create a string of 40 random characters from the set [a-p],
|
||||
can be used as a unique identifier of objects.
|
||||
@@ -266,6 +269,7 @@ def create_id():
|
||||
ret += chr(random.randint(0, 15) + ord('a'))
|
||||
return ret
|
||||
|
||||
|
||||
def make_temp(string, suffix="", decode=True):
|
||||
""" xmlsec needs files in some cases where only strings exist, hence the
|
||||
need for this function. It creates a temporary file with the
|
||||
@@ -288,10 +292,35 @@ def make_temp(string, suffix="", decode=True):
|
||||
ntf.seek(0)
|
||||
return ntf, ntf.name
|
||||
|
||||
|
||||
def split_len(seq, length):
|
||||
return [seq[i:i+length] for i in range(0, len(seq), length)]
|
||||
|
||||
def cert_from_key_info(key_info):
|
||||
return [seq[i:i + length] for i in range(0, len(seq), length)]
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
M2_TIME_FORMAT = "%b %d %H:%M:%S %Y"
|
||||
|
||||
|
||||
def to_time(_time):
|
||||
assert _time.endswith(" GMT")
|
||||
_time = _time[:-4]
|
||||
return mktime(str_to_time(_time, M2_TIME_FORMAT))
|
||||
|
||||
|
||||
def active_cert(key):
|
||||
cert_str = pem_format(key)
|
||||
certificate = load_cert_string(cert_str)
|
||||
not_before = to_time(str(certificate.get_not_before()))
|
||||
not_after = to_time(str(certificate.get_not_after()))
|
||||
try:
|
||||
assert not_before < utc_now()
|
||||
assert not_after > utc_now()
|
||||
return True
|
||||
except AssertionError:
|
||||
return False
|
||||
|
||||
|
||||
def cert_from_key_info(key_info, ignore_age=False):
|
||||
""" Get all X509 certs from a KeyInfo instance. Care is taken to make sure
|
||||
that the certs are continues sequences of bytes.
|
||||
|
||||
@@ -307,12 +336,16 @@ def cert_from_key_info(key_info):
|
||||
#print "X509Data",x509_data
|
||||
x509_certificate = x509_data.x509_certificate
|
||||
cert = x509_certificate.text.strip()
|
||||
cert = "\n".join(split_len("".join([
|
||||
s.strip() for s in cert.split()]),64))
|
||||
res.append(cert)
|
||||
cert = "\n".join(split_len("".join([s.strip() for s in
|
||||
cert.split()]), 64))
|
||||
if ignore_age or active_cert(cert):
|
||||
res.append(cert)
|
||||
else:
|
||||
logger.info("Inactive cert")
|
||||
return res
|
||||
|
||||
def cert_from_key_info_dict(key_info):
|
||||
|
||||
def cert_from_key_info_dict(key_info, ignore_age=False):
|
||||
""" Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure
|
||||
that the certs are continues sequences of bytes.
|
||||
|
||||
@@ -330,11 +363,15 @@ def cert_from_key_info_dict(key_info):
|
||||
for x509_data in key_info["x509_data"]:
|
||||
x509_certificate = x509_data["x509_certificate"]
|
||||
cert = x509_certificate["text"].strip()
|
||||
cert = "\n".join(split_len("".join([
|
||||
s.strip() for s in cert.split()]),64))
|
||||
res.append(cert)
|
||||
cert = "\n".join(split_len("".join([s.strip() for s in
|
||||
cert.split()]), 64))
|
||||
if ignore_age or active_cert(cert):
|
||||
res.append(cert)
|
||||
else:
|
||||
logger.info("Inactive cert")
|
||||
return res
|
||||
|
||||
|
||||
def cert_from_instance(instance):
|
||||
""" Find certificates that are part of an instance
|
||||
|
||||
@@ -343,32 +380,38 @@ def cert_from_instance(instance):
|
||||
"""
|
||||
if instance.signature:
|
||||
if instance.signature.key_info:
|
||||
return cert_from_key_info(instance.signature.key_info)
|
||||
return cert_from_key_info(instance.signature.key_info,
|
||||
ignore_age=True)
|
||||
return []
|
||||
|
||||
# =============================================================================
|
||||
from M2Crypto.__m2crypto import bn_to_mpi
|
||||
from M2Crypto.__m2crypto import hex_to_bn
|
||||
|
||||
|
||||
def intarr2long(arr):
|
||||
return long(''.join(["%02x" % byte for byte in arr]), 16)
|
||||
|
||||
|
||||
def dehexlify(bi):
|
||||
s = hexlify(bi)
|
||||
return [int(s[i]+s[i+1], 16) for i in range(0,len(s),2)]
|
||||
return [int(s[i] + s[i + 1], 16) for i in range(0, len(s), 2)]
|
||||
|
||||
|
||||
def long_to_mpi(num):
|
||||
"""Converts a python integer or long to OpenSSL MPInt used by M2Crypto.
|
||||
Borrowed from Snowball.Shared.Crypto"""
|
||||
h = hex(num)[2:] # strip leading 0x in string
|
||||
h = hex(num)[2:] # strip leading 0x in string
|
||||
if len(h) % 2 == 1:
|
||||
h = '0' + h # add leading 0 to get even number of hexdigits
|
||||
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
|
||||
h = '0' + h # add leading 0 to get even number of hexdigits
|
||||
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
|
||||
|
||||
|
||||
def base64_to_long(data):
|
||||
_d = base64.urlsafe_b64decode(data + '==')
|
||||
return intarr2long(dehexlify(_d))
|
||||
|
||||
|
||||
def key_from_key_value(key_info):
|
||||
res = []
|
||||
for value in key_info.key_value:
|
||||
@@ -376,10 +419,11 @@ def key_from_key_value(key_info):
|
||||
e = base64_to_long(value.rsa_key_value.exponent)
|
||||
m = base64_to_long(value.rsa_key_value.modulus)
|
||||
key = M2Crypto.RSA.new_pub_key((long_to_mpi(e),
|
||||
long_to_mpi(m)))
|
||||
long_to_mpi(m)))
|
||||
res.append(key)
|
||||
return res
|
||||
|
||||
|
||||
def key_from_key_value_dict(key_info):
|
||||
res = []
|
||||
if not "key_value" in key_info:
|
||||
@@ -396,10 +440,28 @@ def key_from_key_value_dict(key_info):
|
||||
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def rsa_load(filename):
|
||||
"""Read a PEM-encoded RSA key pair from a file."""
|
||||
return M2Crypto.RSA.load_key(filename, M2Crypto.util.no_passphrase_callback)
|
||||
|
||||
|
||||
def rsa_loads(key):
|
||||
"""Read a PEM-encoded RSA key pair from a string."""
|
||||
return M2Crypto.RSA.load_key_string(key,
|
||||
M2Crypto.util.no_passphrase_callback)
|
||||
|
||||
|
||||
def x509_rsa_loads(string):
|
||||
cert = M2Crypto.X509.load_cert_string(string)
|
||||
return cert.get_pubkey().get_rsa()
|
||||
|
||||
|
||||
def pem_format(key):
|
||||
return "\n".join(["-----BEGIN CERTIFICATE-----",
|
||||
key,"-----END CERTIFICATE-----"])
|
||||
key, "-----END CERTIFICATE-----"])
|
||||
|
||||
|
||||
def parse_xmlsec_output(output):
|
||||
""" Parse the output from xmlsec to try to find out if the
|
||||
command was successfull or not.
|
||||
@@ -416,12 +478,84 @@ def parse_xmlsec_output(output):
|
||||
|
||||
__DEBUG = 0
|
||||
|
||||
LOG_LINE = 60*"="+"\n%s\n"+60*"-"+"\n%s"+60*"="
|
||||
LOG_LINE_2 = 60*"="+"\n%s\n%s\n"+60*"-"+"\n%s"+60*"="
|
||||
|
||||
class BadSignature(Exception):
|
||||
"""The signature is invalid."""
|
||||
pass
|
||||
|
||||
|
||||
def sha1_digest(msg):
|
||||
return hashlib.sha1(msg).digest()
|
||||
|
||||
|
||||
class Signer(object):
|
||||
"""Abstract base class for signing algorithms."""
|
||||
def sign(self, msg, key):
|
||||
"""Sign ``msg`` with ``key`` and return the signature."""
|
||||
raise NotImplementedError
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
"""Return True if ``sig`` is a valid signature for ``msg``."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RSASigner(Signer):
|
||||
def __init__(self, digest, algo):
|
||||
self.digest = digest
|
||||
self.algo = algo
|
||||
|
||||
def sign(self, msg, key):
|
||||
return key.sign(self.digest(msg), self.algo)
|
||||
|
||||
def verify(self, msg, sig, key):
|
||||
try:
|
||||
return key.verify(self.digest(msg), sig, self.algo)
|
||||
except M2Crypto.RSA.RSAError, e:
|
||||
raise BadSignature(e)
|
||||
|
||||
|
||||
REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"]
|
||||
RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"]
|
||||
|
||||
|
||||
def verify_redirect_signature(info, cert):
|
||||
"""
|
||||
|
||||
:param info: A dictionary as produced by parse_qs, means all values are
|
||||
lists.
|
||||
:param cert: A certificate to use when verifying the signature
|
||||
:return: True, if signature verified
|
||||
"""
|
||||
|
||||
if info["SigAlg"][0] == RSA_SHA1:
|
||||
if "SAMLRequest" in info:
|
||||
_order = REQ_ORDER
|
||||
elif "SAMLResponse" in info:
|
||||
_order = RESP_ORDER
|
||||
else:
|
||||
raise Unsupported(
|
||||
"Verifying signature on something that should not be signed")
|
||||
signer = RSASigner(sha1_digest, "sha1")
|
||||
args = info.copy()
|
||||
del args["Signature"] # everything but the signature
|
||||
string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order])
|
||||
_key = x509_rsa_loads(pem_format(cert))
|
||||
_sign = base64.b64decode(info["Signature"][0])
|
||||
try:
|
||||
signer.verify(string, _sign, _key)
|
||||
return True
|
||||
except BadSignature:
|
||||
return False
|
||||
else:
|
||||
raise Unsupported("Signature algorithm: %s" % info["SigAlg"])
|
||||
|
||||
LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||
LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||
|
||||
|
||||
def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
|
||||
node_name=NODE_NAME, debug=False, node_id=None,
|
||||
id_attr=""):
|
||||
node_name=NODE_NAME, debug=False, node_id=None,
|
||||
id_attr=""):
|
||||
""" Verifies the signature of a XML document.
|
||||
|
||||
:param enctext: The signed XML document
|
||||
@@ -481,6 +615,7 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_cert_from_file(cert_file, cert_type):
|
||||
""" Reads a certificate from a file. The assumption is that there is
|
||||
only one certificate in the file
|
||||
@@ -516,6 +651,7 @@ def read_cert_from_file(cert_file, cert_type):
|
||||
data = open(cert_file).read()
|
||||
return base64.b64encode(str(data))
|
||||
|
||||
|
||||
def security_context(conf, debug=None):
|
||||
""" Creates a security context based on the configuration
|
||||
|
||||
@@ -535,14 +671,15 @@ def security_context(conf, debug=None):
|
||||
_only_md = False
|
||||
|
||||
return SecurityContext(conf.xmlsec_binary, conf.key_file,
|
||||
cert_file=conf.cert_file, metadata=metadata,
|
||||
debug=debug, only_use_keys_in_metadata=_only_md)
|
||||
cert_file=conf.cert_file, metadata=metadata,
|
||||
debug=debug, only_use_keys_in_metadata=_only_md)
|
||||
|
||||
|
||||
class SecurityContext(object):
|
||||
def __init__(self, xmlsec_binary, key_file="", key_type= "pem",
|
||||
cert_file="", cert_type="pem", metadata=None,
|
||||
debug=False, template="", encrypt_key_type="des-192",
|
||||
only_use_keys_in_metadata=False):
|
||||
def __init__(self, xmlsec_binary, key_file="", key_type="pem",
|
||||
cert_file="", cert_type="pem", metadata=None,
|
||||
debug=False, template="", encrypt_key_type="des-192",
|
||||
only_use_keys_in_metadata=False):
|
||||
|
||||
self.xmlsec = xmlsec_binary
|
||||
|
||||
@@ -592,12 +729,9 @@ class SecurityContext(object):
|
||||
_, fil = make_temp("%s" % text, decode=False)
|
||||
ntf = NamedTemporaryFile()
|
||||
|
||||
com_list = [self.xmlsec, "--encrypt",
|
||||
"--pubkey-pem", recv_key,
|
||||
"--session-key", key_type,
|
||||
"--xml-data", fil,
|
||||
"--output", ntf.name,
|
||||
template]
|
||||
com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key,
|
||||
"--session-key", key_type, "--xml-data", fil,
|
||||
"--output", ntf.name, template]
|
||||
|
||||
logger.debug("Encryption command: %s" % " ".join(com_list))
|
||||
|
||||
@@ -625,11 +759,9 @@ class SecurityContext(object):
|
||||
_, fil = make_temp("%s" % enctext, decode=False)
|
||||
ntf = NamedTemporaryFile()
|
||||
|
||||
com_list = [self.xmlsec, "--decrypt",
|
||||
"--privkey-pem", self.key_file,
|
||||
"--output", ntf.name,
|
||||
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS,
|
||||
fil]
|
||||
com_list = [self.xmlsec, "--decrypt", "--privkey-pem",
|
||||
self.key_file, "--output", ntf.name,
|
||||
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, fil]
|
||||
|
||||
logger.debug("Decrypt command: %s" % " ".join(com_list))
|
||||
|
||||
@@ -646,9 +778,8 @@ class SecurityContext(object):
|
||||
ntf.seek(0)
|
||||
return ntf.read()
|
||||
|
||||
|
||||
def verify_signature(self, enctext, cert_file=None, cert_type="pem",
|
||||
node_name=NODE_NAME, node_id=None, id_attr=""):
|
||||
node_name=NODE_NAME, node_id=None, id_attr=""):
|
||||
""" Verifies the signature of a XML document.
|
||||
|
||||
:param enctext: The XML document as a string
|
||||
@@ -773,22 +904,22 @@ class SecurityContext(object):
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_authn_query(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "authn_query",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_logout_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "logout_request",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_logout_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "logout_response",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_attribute_query(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "attribute_query",
|
||||
must, origdoc)
|
||||
|
||||
@@ -799,31 +930,31 @@ class SecurityContext(object):
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_authz_decision_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"authz_decision_response", must,
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"name_id_mapping_request",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"name_id_mapping_response",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_artifact_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"artifact_request",
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_artifact_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"artifact_response",
|
||||
must, origdoc)
|
||||
@@ -835,19 +966,19 @@ class SecurityContext(object):
|
||||
must, origdoc)
|
||||
|
||||
def correctly_signed_manage_name_id_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"manage_name_id_response", must,
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_assertion_id_request(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml,
|
||||
"assertion_id_request", must,
|
||||
origdoc)
|
||||
|
||||
def correctly_signed_assertion_id_response(self, decoded_xml, must=False,
|
||||
origdoc=None):
|
||||
origdoc=None):
|
||||
return self.correctly_signed_message(decoded_xml, "assertion", must,
|
||||
origdoc)
|
||||
|
||||
@@ -882,18 +1013,16 @@ class SecurityContext(object):
|
||||
|
||||
try:
|
||||
self._check_signature(decoded_xml, assertion,
|
||||
class_name(assertion), origdoc)
|
||||
class_name(assertion), origdoc)
|
||||
except Exception, exc:
|
||||
logger.error("correctly_signed_response: %s" % exc)
|
||||
raise
|
||||
|
||||
return response
|
||||
|
||||
|
||||
#--------------------------------------------------------------------------
|
||||
# SIGNATURE PART
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
def sign_statement_using_xmlsec(self, statement, klass_namn, key=None,
|
||||
key_file=None, nodeid=None, id_attr=""):
|
||||
"""Sign a SAML statement using xmlsec.
|
||||
@@ -917,7 +1046,6 @@ class SecurityContext(object):
|
||||
|
||||
_, fil = make_temp("%s" % statement, decode=False)
|
||||
|
||||
|
||||
ntf = NamedTemporaryFile()
|
||||
|
||||
com_list = [self.xmlsec, "--sign",
|
||||
@@ -975,10 +1103,8 @@ class SecurityContext(object):
|
||||
:return: The signed statement
|
||||
"""
|
||||
|
||||
return self.sign_statement_using_xmlsec(statement,
|
||||
class_name(samlp.AttributeQuery()),
|
||||
key, key_file, nodeid,
|
||||
id_attr=id_attr)
|
||||
return self.sign_statement_using_xmlsec(statement, class_name(
|
||||
samlp.AttributeQuery()), key, key_file, nodeid, id_attr=id_attr)
|
||||
|
||||
def multiple_signatures(self, statement, to_sign, key=None, key_file=None):
|
||||
"""
|
||||
@@ -991,15 +1117,15 @@ class SecurityContext(object):
|
||||
:param key_file: A file that contains the key to be used
|
||||
:return: A possibly multiple signed statement
|
||||
"""
|
||||
for (item, id, id_attr) in to_sign:
|
||||
if not id:
|
||||
for (item, sid, id_attr) in to_sign:
|
||||
if not sid:
|
||||
if not item.id:
|
||||
id = item.id = sid()
|
||||
sid = item.id = sid()
|
||||
else:
|
||||
id = item.id
|
||||
sid = item.id
|
||||
|
||||
if not item.signature:
|
||||
item.signature = pre_signature_part(id, self.cert_file)
|
||||
item.signature = pre_signature_part(sid, self.cert_file)
|
||||
|
||||
statement = self.sign_statement_using_xmlsec(statement,
|
||||
class_name(item),
|
||||
@@ -1024,32 +1150,30 @@ def pre_signature_part(ident, public_key=None, identifier=None):
|
||||
:return: A preset signature part
|
||||
"""
|
||||
|
||||
signature_method = ds.SignatureMethod(algorithm = ds.SIG_RSA_SHA1)
|
||||
signature_method = ds.SignatureMethod(algorithm=ds.SIG_RSA_SHA1)
|
||||
canonicalization_method = ds.CanonicalizationMethod(
|
||||
algorithm = ds.ALG_EXC_C14N)
|
||||
trans0 = ds.Transform(algorithm = ds.TRANSFORM_ENVELOPED)
|
||||
trans1 = ds.Transform(algorithm = ds.ALG_EXC_C14N)
|
||||
transforms = ds.Transforms(transform = [trans0, trans1])
|
||||
digest_method = ds.DigestMethod(algorithm = ds.DIGEST_SHA1)
|
||||
algorithm=ds.ALG_EXC_C14N)
|
||||
trans0 = ds.Transform(algorithm=ds.TRANSFORM_ENVELOPED)
|
||||
trans1 = ds.Transform(algorithm=ds.ALG_EXC_C14N)
|
||||
transforms = ds.Transforms(transform=[trans0, trans1])
|
||||
digest_method = ds.DigestMethod(algorithm=ds.DIGEST_SHA1)
|
||||
|
||||
reference = ds.Reference(uri = "#%s" % ident,
|
||||
digest_value = ds.DigestValue(),
|
||||
transforms = transforms,
|
||||
digest_method = digest_method)
|
||||
reference = ds.Reference(uri="#%s" % ident, digest_value=ds.DigestValue(),
|
||||
transforms=transforms, digest_method=digest_method)
|
||||
|
||||
signed_info = ds.SignedInfo(signature_method = signature_method,
|
||||
canonicalization_method = canonicalization_method,
|
||||
reference = reference)
|
||||
signed_info = ds.SignedInfo(signature_method=signature_method,
|
||||
canonicalization_method=canonicalization_method,
|
||||
reference=reference)
|
||||
|
||||
signature = ds.Signature(signed_info=signed_info,
|
||||
signature_value=ds.SignatureValue())
|
||||
signature = ds.Signature(signed_info=signed_info,
|
||||
signature_value=ds.SignatureValue())
|
||||
|
||||
if identifier:
|
||||
signature.id = "Signature%d" % identifier
|
||||
|
||||
if public_key:
|
||||
x509_data = ds.X509Data(x509_certificate=[ds.X509Certificate(
|
||||
text=public_key)])
|
||||
x509_data = ds.X509Data(
|
||||
x509_certificate=[ds.X509Certificate(text=public_key)])
|
||||
key_info = ds.KeyInfo(x509_data=x509_data)
|
||||
signature.key_info = key_info
|
||||
|
||||
|
@@ -149,16 +149,16 @@ def add_duration(tid, duration):
|
||||
if days < 1:
|
||||
pass
|
||||
elif days > maximum_day_in_month_for(year, month):
|
||||
days = days - maximum_day_in_month_for(year, month)
|
||||
days -= maximum_day_in_month_for(year, month)
|
||||
carry = 1
|
||||
else:
|
||||
break
|
||||
temp = month + carry
|
||||
month = modulo(temp, 1, 13)
|
||||
year = year + f_quotient(temp, 1, 13)
|
||||
year += f_quotient(temp, 1, 13)
|
||||
|
||||
return time.localtime(time.mktime((year, month, days, hour, minutes,
|
||||
secs, 0, 0, -1)))
|
||||
return time.localtime(time.mktime((year, month, days, hour, minutes,
|
||||
secs, 0, 0, -1)))
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -232,7 +232,7 @@ def str_to_time(timestr, format=TIME_FORMAT):
|
||||
return 0
|
||||
try:
|
||||
then = time.strptime(timestr, format)
|
||||
except Exception: # assume it's a format problem
|
||||
except ValueError, err: # assume it's a format problem
|
||||
try:
|
||||
elem = TIME_FORMAT_WITH_FRAGMENT.match(timestr)
|
||||
except Exception, exc:
|
||||
|
44
tests/test_70_redirect_signing.py
Normal file
44
tests/test_70_redirect_signing.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from saml2.pack import http_redirect_message
|
||||
from saml2.sigver import verify_redirect_signature
|
||||
from saml2.sigver import RSA_SHA1
|
||||
from saml2.server import Server
|
||||
from saml2 import BINDING_HTTP_REDIRECT
|
||||
from saml2.client import Saml2Client
|
||||
from saml2.config import SPConfig
|
||||
from saml2.sigver import rsa_load
|
||||
from urlparse import parse_qs
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
idp = Server(config_file="idp_all_conf")
|
||||
|
||||
conf = SPConfig()
|
||||
conf.load_file("servera_conf")
|
||||
sp = Saml2Client(conf)
|
||||
|
||||
def test():
|
||||
srvs = sp.metadata.single_sign_on_service(idp.config.entityid,
|
||||
BINDING_HTTP_REDIRECT)
|
||||
|
||||
destination = srvs[0]["location"]
|
||||
req = sp.create_authn_request(destination, id="id1")
|
||||
|
||||
try:
|
||||
key = sp.sec.key
|
||||
except AttributeError:
|
||||
key = rsa_load(sp.sec.key_file)
|
||||
|
||||
info = http_redirect_message(req, destination, relay_state="RS",
|
||||
typ="SAMLRequest", sigalg=RSA_SHA1, key=key)
|
||||
|
||||
verified_ok = False
|
||||
|
||||
for param, val in info["headers"]:
|
||||
if param == "Location":
|
||||
_dict = parse_qs(val.split("?")[1])
|
||||
_certs = idp.metadata.certs(sp.config.entityid, "any", "signing")
|
||||
for cert in _certs:
|
||||
if verify_redirect_signature(_dict, cert):
|
||||
verified_ok = True
|
||||
|
||||
assert verified_ok
|
Reference in New Issue
Block a user