Merge branch 'master' of https://github.com/rohe/pysaml2
This commit is contained in:
@@ -31,6 +31,7 @@ from saml2.saml import AUTHN_PASSWORD
|
|||||||
|
|
||||||
logger = logging.getLogger("saml2.idp")
|
logger = logging.getLogger("saml2.idp")
|
||||||
|
|
||||||
|
|
||||||
def _expiration(timeout, tformat="%a, %d-%b-%Y %H:%M:%S GMT"):
|
def _expiration(timeout, tformat="%a, %d-%b-%Y %H:%M:%S GMT"):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -50,7 +51,7 @@ def _expiration(timeout, tformat="%a, %d-%b-%Y %H:%M:%S GMT"):
|
|||||||
|
|
||||||
|
|
||||||
def dict2list_of_tuples(d):
|
def dict2list_of_tuples(d):
|
||||||
return [(k,v) for k,v in d.items()]
|
return [(k, v) for k, v in d.items()]
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
@@ -65,7 +66,7 @@ class Service(object):
|
|||||||
def unpack_redirect(self):
|
def unpack_redirect(self):
|
||||||
if "QUERY_STRING" in self.environ:
|
if "QUERY_STRING" in self.environ:
|
||||||
_qs = self.environ["QUERY_STRING"]
|
_qs = self.environ["QUERY_STRING"]
|
||||||
return dict([(k,v[0]) for k,v in parse_qs(_qs).items()])
|
return dict([(k, v[0]) for k, v in parse_qs(_qs).items()])
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -270,8 +271,13 @@ class SSO(Service):
|
|||||||
IDP.ticket[key] = _dict
|
IDP.ticket[key] = _dict
|
||||||
_resp = key
|
_resp = key
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
_resp = IDP.ticket[_dict["key"]]
|
_resp = IDP.ticket[_dict["key"]]
|
||||||
del IDP.ticket[_dict["key"]]
|
del IDP.ticket[_dict["key"]]
|
||||||
|
except KeyError:
|
||||||
|
key = sha1("%s" % _dict).hexdigest()
|
||||||
|
IDP.ticket[key] = _dict
|
||||||
|
_resp = key
|
||||||
|
|
||||||
return _resp
|
return _resp
|
||||||
|
|
||||||
@@ -391,8 +397,9 @@ def do_verify(environ, start_response, _):
|
|||||||
if not _ok:
|
if not _ok:
|
||||||
resp = Unauthorized("Unknown user or wrong password")
|
resp = Unauthorized("Unknown user or wrong password")
|
||||||
else:
|
else:
|
||||||
uid = rndstr()
|
uid = rndstr(24)
|
||||||
IDP.authn[uid] = user
|
IDP.uid2user[uid] = user
|
||||||
|
IDP.user2uid[user] = uid
|
||||||
logger.debug("Register %s under '%s'" % (user, uid))
|
logger.debug("Register %s under '%s'" % (user, uid))
|
||||||
kaka = set_cookie("idpauthn", "/", uid)
|
kaka = set_cookie("idpauthn", "/", uid)
|
||||||
lox = "http://%s%s?id=%s&key=%s" % (environ["HTTP_HOST"],
|
lox = "http://%s%s?id=%s&key=%s" % (environ["HTTP_HOST"],
|
||||||
@@ -437,6 +444,8 @@ class SLO(Service):
|
|||||||
if msg.name_id:
|
if msg.name_id:
|
||||||
lid = IDP.ident.find_local_id(msg.name_id)
|
lid = IDP.ident.find_local_id(msg.name_id)
|
||||||
logger.info("local identifier: %s" % lid)
|
logger.info("local identifier: %s" % lid)
|
||||||
|
del IDP.uid2user[IDP.user2uid[lid]]
|
||||||
|
del IDP.user2uid[lid]
|
||||||
# remove the authentication
|
# remove the authentication
|
||||||
try:
|
try:
|
||||||
IDP.remove_authn_statements(msg.name_id)
|
IDP.remove_authn_statements(msg.name_id)
|
||||||
@@ -445,7 +454,7 @@ class SLO(Service):
|
|||||||
resp = ServiceError("%s" % exc)
|
resp = ServiceError("%s" % exc)
|
||||||
return resp(self.environ, self.start_response)
|
return resp(self.environ, self.start_response)
|
||||||
|
|
||||||
resp = IDP.create_logout_response(msg)
|
resp = IDP.create_logout_response(msg, [binding])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
hinfo = IDP.apply_binding(binding, "%s" % resp, "", relay_state)
|
hinfo = IDP.apply_binding(binding, "%s" % resp, "", relay_state)
|
||||||
@@ -454,11 +463,11 @@ class SLO(Service):
|
|||||||
resp = ServiceError("%s" % exc)
|
resp = ServiceError("%s" % exc)
|
||||||
return resp(self.environ, self.start_response)
|
return resp(self.environ, self.start_response)
|
||||||
|
|
||||||
logger.info("Header: %s" % (hinfo["headers"],))
|
|
||||||
#_tlh = dict2list_of_tuples(hinfo["headers"])
|
#_tlh = dict2list_of_tuples(hinfo["headers"])
|
||||||
delco = delete_cookie(self.environ, "idpauthn")
|
delco = delete_cookie(self.environ, "idpauthn")
|
||||||
if delco:
|
if delco:
|
||||||
hinfo["headers"].append(delco)
|
hinfo["headers"].append(delco)
|
||||||
|
logger.info("Header: %s" % (hinfo["headers"],))
|
||||||
resp = Response(hinfo["data"], headers=hinfo["headers"])
|
resp = Response(hinfo["data"], headers=hinfo["headers"])
|
||||||
return resp(self.environ, self.start_response)
|
return resp(self.environ, self.start_response)
|
||||||
|
|
||||||
@@ -475,8 +484,8 @@ class NMI(Service):
|
|||||||
request = req.message
|
request = req.message
|
||||||
|
|
||||||
# Do the necessary stuff
|
# Do the necessary stuff
|
||||||
name_id = IDP.ident.handle_manage_name_id_request(request.name_id,
|
name_id = IDP.ident.handle_manage_name_id_request(
|
||||||
request.new_id, request.new_encrypted_id,
|
request.name_id, request.new_id, request.new_encrypted_id,
|
||||||
request.terminate)
|
request.terminate)
|
||||||
|
|
||||||
logger.debug("New NameID: %s" % name_id)
|
logger.debug("New NameID: %s" % name_id)
|
||||||
@@ -606,8 +615,8 @@ class NIM(Service):
|
|||||||
request = req.message
|
request = req.message
|
||||||
# Do the necessary stuff
|
# Do the necessary stuff
|
||||||
try:
|
try:
|
||||||
name_id = IDP.ident.handle_name_id_mapping_request(request.name_id,
|
name_id = IDP.ident.handle_name_id_mapping_request(
|
||||||
request.name_id_policy)
|
request.name_id, request.name_id_policy)
|
||||||
except Unknown:
|
except Unknown:
|
||||||
resp = BadRequest("Unknown entity")
|
resp = BadRequest("Unknown entity")
|
||||||
return resp(self.environ, self.start_response)
|
return resp(self.environ, self.start_response)
|
||||||
@@ -638,7 +647,10 @@ def kaka2user(kaka):
|
|||||||
cookie_obj = SimpleCookie(kaka)
|
cookie_obj = SimpleCookie(kaka)
|
||||||
morsel = cookie_obj.get("idpauthn", None)
|
morsel = cookie_obj.get("idpauthn", None)
|
||||||
if morsel:
|
if morsel:
|
||||||
return IDP.authn[morsel.value]
|
try:
|
||||||
|
return IDP.uid2user[morsel.value]
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
logger.debug("No idpauthn cookie")
|
logger.debug("No idpauthn cookie")
|
||||||
return None
|
return None
|
||||||
@@ -646,6 +658,7 @@ def kaka2user(kaka):
|
|||||||
|
|
||||||
def delete_cookie(environ, name):
|
def delete_cookie(environ, name):
|
||||||
kaka = environ.get("HTTP_COOKIE", '')
|
kaka = environ.get("HTTP_COOKIE", '')
|
||||||
|
logger.debug("delete KAKA: %s" % kaka)
|
||||||
if kaka:
|
if kaka:
|
||||||
cookie_obj = SimpleCookie(kaka)
|
cookie_obj = SimpleCookie(kaka)
|
||||||
morsel = cookie_obj.get(name, None)
|
morsel = cookie_obj.get(name, None)
|
||||||
@@ -739,7 +752,7 @@ def application(environ, start_response):
|
|||||||
try:
|
try:
|
||||||
query = parse_qs(environ["QUERY_STRING"])
|
query = parse_qs(environ["QUERY_STRING"])
|
||||||
logger.debug("QUERY: %s" % query)
|
logger.debug("QUERY: %s" % query)
|
||||||
user = IDP.authn[query["id"][0]]
|
user = IDP.uid2user[query["id"][0]]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
user = None
|
user = None
|
||||||
|
|
||||||
|
@@ -145,8 +145,10 @@ def logout(environ, start_response, user):
|
|||||||
# What if more than one
|
# What if more than one
|
||||||
_dict = client.saml_client.global_logout(subject_id)
|
_dict = client.saml_client.global_logout(subject_id)
|
||||||
logger.info("[logout] global_logout > %s" % (_dict,))
|
logger.info("[logout] global_logout > %s" % (_dict,))
|
||||||
|
rem = environ['repoze.who.plugins'][client.rememberer_name]
|
||||||
|
rem.forget(environ, subject_id)
|
||||||
|
|
||||||
for key, item in _dict.item():
|
for key, item in _dict.items():
|
||||||
if isinstance(item, tuple):
|
if isinstance(item, tuple):
|
||||||
binding, htargs = item
|
binding, htargs = item
|
||||||
else: # result from logout, should be OK
|
else: # result from logout, should be OK
|
||||||
@@ -200,13 +202,15 @@ def application(environ, start_response):
|
|||||||
request is done
|
request is done
|
||||||
:return: The response as a list of lines
|
:return: The response as a list of lines
|
||||||
"""
|
"""
|
||||||
|
path = environ.get('PATH_INFO', '').lstrip('/')
|
||||||
|
logger.info("<application> PATH: %s" % path)
|
||||||
|
|
||||||
user = environ.get("REMOTE_USER", "")
|
user = environ.get("REMOTE_USER", "")
|
||||||
if not user:
|
if not user:
|
||||||
user = environ.get("repoze.who.identity", "")
|
user = environ.get("repoze.who.identity", "")
|
||||||
|
logger.info("repoze.who.identity: '%s'" % user)
|
||||||
path = environ.get('PATH_INFO', '').lstrip('/')
|
else:
|
||||||
logger.info("<application> PATH: %s" % path)
|
logger.info("REMOTE_USER: '%s'" % user)
|
||||||
logger.info("logger name: %s" % logger.name)
|
|
||||||
#logger.info(logging.Logger.manager.loggerDict)
|
#logger.info(logging.Logger.manager.loggerDict)
|
||||||
for regex, callback in urls:
|
for regex, callback in urls:
|
||||||
if user:
|
if user:
|
||||||
|
@@ -5,14 +5,14 @@ BASE= "http://localhost:8087"
|
|||||||
#BASE= "http://lingon.catalogix.se:8087"
|
#BASE= "http://lingon.catalogix.se:8087"
|
||||||
|
|
||||||
CONFIG = {
|
CONFIG = {
|
||||||
"entityid" : "%s/sp.xml" % BASE,
|
"entityid": "%s/sp.xml" % BASE,
|
||||||
"description": "My SP",
|
"description": "My SP",
|
||||||
"service": {
|
"service": {
|
||||||
"sp":{
|
"sp": {
|
||||||
"name" : "Rolands SP",
|
"name" : "Rolands SP",
|
||||||
"endpoints":{
|
"endpoints": {
|
||||||
"assertion_consumer_service": [BASE],
|
"assertion_consumer_service": [BASE],
|
||||||
"single_logout_service" : [(BASE+"/slo",
|
"single_logout_service": [(BASE + "/slo",
|
||||||
BINDING_HTTP_REDIRECT)],
|
BINDING_HTTP_REDIRECT)],
|
||||||
},
|
},
|
||||||
"required_attributes": ["surname", "givenname",
|
"required_attributes": ["surname", "givenname",
|
||||||
@@ -20,18 +20,16 @@ CONFIG = {
|
|||||||
"optional_attributes": ["title"],
|
"optional_attributes": ["title"],
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"debug" : 1,
|
"debug": 1,
|
||||||
"key_file" : "pki/mykey.pem",
|
"key_file": "pki/mykey.pem",
|
||||||
"cert_file" : "pki/mycert.pem",
|
"cert_file": "pki/mycert.pem",
|
||||||
"attribute_map_dir" : "./attributemaps",
|
"attribute_map_dir": "./attributemaps",
|
||||||
"metadata" : {
|
"metadata": {"local": ["../idp2/idp.xml"]},
|
||||||
"local": ["../idp2/idp.xml"],
|
|
||||||
},
|
|
||||||
# -- below used by make_metadata --
|
# -- below used by make_metadata --
|
||||||
"organization": {
|
"organization": {
|
||||||
"name": "Exempel AB",
|
"name": "Exempel AB",
|
||||||
"display_name": [("Exempel AB","se"),("Example Co.","en")],
|
"display_name": [("Exempel AB", "se"), ("Example Co.", "en")],
|
||||||
"url":"http://www.example.com/roland",
|
"url": "http://www.example.com/roland",
|
||||||
},
|
},
|
||||||
"contact_person": [{
|
"contact_person": [{
|
||||||
"given_name":"John",
|
"given_name":"John",
|
||||||
|
@@ -14,9 +14,8 @@ reissue_time = 3000
|
|||||||
[plugin:saml2auth]
|
[plugin:saml2auth]
|
||||||
use = s2repoze.plugins.sp:make_plugin
|
use = s2repoze.plugins.sp:make_plugin
|
||||||
saml_conf = sp_conf
|
saml_conf = sp_conf
|
||||||
rememberer_name = auth_tkt
|
remember_name = auth_tkt
|
||||||
sid_store = outstanding
|
sid_store = outstanding
|
||||||
identity_cache = identities
|
|
||||||
|
|
||||||
[general]
|
[general]
|
||||||
request_classifier = s2repoze.plugins.challenge_decider:my_request_classifier
|
request_classifier = s2repoze.plugins.challenge_decider:my_request_classifier
|
||||||
|
@@ -42,6 +42,7 @@ from saml2 import ecp, BINDING_HTTP_REDIRECT
|
|||||||
from saml2 import BINDING_HTTP_POST
|
from saml2 import BINDING_HTTP_POST
|
||||||
|
|
||||||
from saml2.client import Saml2Client
|
from saml2.client import Saml2Client
|
||||||
|
from saml2.ident import code, decode
|
||||||
from saml2.s_utils import sid
|
from saml2.s_utils import sid
|
||||||
from saml2.config import config_factory
|
from saml2.config import config_factory
|
||||||
from saml2.profile import paos
|
from saml2.profile import paos
|
||||||
@@ -53,16 +54,18 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
PAOS_HEADER_INFO = 'ver="%s";"%s"' % (paos.NAMESPACE, ECP_SERVICE)
|
PAOS_HEADER_INFO = 'ver="%s";"%s"' % (paos.NAMESPACE, ECP_SERVICE)
|
||||||
|
|
||||||
|
|
||||||
def construct_came_from(environ):
|
def construct_came_from(environ):
|
||||||
""" The URL that the user used when the process where interupted
|
""" The URL that the user used when the process where interupted
|
||||||
for single-sign-on processing. """
|
for single-sign-on processing. """
|
||||||
|
|
||||||
came_from = environ.get("PATH_INFO")
|
came_from = environ.get("PATH_INFO")
|
||||||
qstr = environ.get("QUERY_STRING","")
|
qstr = environ.get("QUERY_STRING", "")
|
||||||
if qstr:
|
if qstr:
|
||||||
came_from += '?' + qstr
|
came_from += '?' + qstr
|
||||||
return came_from
|
return came_from
|
||||||
|
|
||||||
|
|
||||||
# FormPluginBase defines the methods remember and forget
|
# FormPluginBase defines the methods remember and forget
|
||||||
def cgi_field_storage_to_dict(field_storage):
|
def cgi_field_storage_to_dict(field_storage):
|
||||||
"""Get a plain dictionary, rather than the '.value' system used by the
|
"""Get a plain dictionary, rather than the '.value' system used by the
|
||||||
@@ -71,16 +74,15 @@ def cgi_field_storage_to_dict(field_storage):
|
|||||||
params = {}
|
params = {}
|
||||||
for key in field_storage.keys():
|
for key in field_storage.keys():
|
||||||
try:
|
try:
|
||||||
params[ key ] = field_storage[ key ].value
|
params[key] = field_storage[key].value
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
if isinstance(field_storage[ key ], basestring):
|
if isinstance(field_storage[key], basestring):
|
||||||
params[key] = field_storage[key]
|
params[key] = field_storage[key]
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def get_body(environ):
|
|
||||||
body = ""
|
|
||||||
|
|
||||||
|
def get_body(environ):
|
||||||
length = int(environ["CONTENT_LENGTH"])
|
length = int(environ["CONTENT_LENGTH"])
|
||||||
try:
|
try:
|
||||||
body = environ["wsgi.input"].read(length)
|
body = environ["wsgi.input"].read(length)
|
||||||
@@ -95,11 +97,13 @@ def get_body(environ):
|
|||||||
|
|
||||||
return body
|
return body
|
||||||
|
|
||||||
|
|
||||||
def exception_trace(tag, exc, log):
|
def exception_trace(tag, exc, log):
|
||||||
message = traceback.format_exception(*sys.exc_info())
|
message = traceback.format_exception(*sys.exc_info())
|
||||||
log.error("[%s] ExcList: %s" % (tag, "".join(message),))
|
log.error("[%s] ExcList: %s" % (tag, "".join(message),))
|
||||||
log.error("[%s] Exception: %s" % (tag, exc))
|
log.error("[%s] Exception: %s" % (tag, exc))
|
||||||
|
|
||||||
|
|
||||||
class ECP_response(object):
|
class ECP_response(object):
|
||||||
code = 200
|
code = 200
|
||||||
title = 'OK'
|
title = 'OK'
|
||||||
@@ -114,13 +118,12 @@ class ECP_response(object):
|
|||||||
return [self.content]
|
return [self.content]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SAML2Plugin(FormPluginBase):
|
class SAML2Plugin(FormPluginBase):
|
||||||
|
|
||||||
implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider)
|
implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider)
|
||||||
|
|
||||||
def __init__(self, rememberer_name, config, saml_client,
|
def __init__(self, rememberer_name, config, saml_client, wayf, cache,
|
||||||
wayf, cache, debug, sid_store=None, discovery=""):
|
debug, sid_store=None, discovery=""):
|
||||||
FormPluginBase.__init__(self)
|
FormPluginBase.__init__(self)
|
||||||
|
|
||||||
self.rememberer_name = rememberer_name
|
self.rememberer_name = rememberer_name
|
||||||
@@ -148,8 +151,6 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
:param environ: A dictionary with environment variables
|
:param environ: A dictionary with environment variables
|
||||||
"""
|
"""
|
||||||
|
|
||||||
post = {}
|
|
||||||
|
|
||||||
post_env = environ.copy()
|
post_env = environ.copy()
|
||||||
post_env['QUERY_STRING'] = ''
|
post_env['QUERY_STRING'] = ''
|
||||||
|
|
||||||
@@ -173,7 +174,7 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
sid_ = sid()
|
sid_ = sid()
|
||||||
self.outstanding_queries[sid_] = came_from
|
self.outstanding_queries[sid_] = came_from
|
||||||
logger.info("Redirect to WAYF function: %s" % self.wayf)
|
logger.info("Redirect to WAYF function: %s" % self.wayf)
|
||||||
return -1, HTTPSeeOther(headers = [('Location',
|
return -1, HTTPSeeOther(headers=[('Location',
|
||||||
"%s?%s" % (self.wayf, sid_))])
|
"%s?%s" % (self.wayf, sid_))])
|
||||||
|
|
||||||
#noinspection PyUnusedLocal
|
#noinspection PyUnusedLocal
|
||||||
@@ -189,6 +190,8 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
# 'PAOS' : 'ver="%s";"%s"' % (paos.NAMESPACE, SERVICE)
|
# 'PAOS' : 'ver="%s";"%s"' % (paos.NAMESPACE, SERVICE)
|
||||||
# }
|
# }
|
||||||
|
|
||||||
|
_cli = self.saml_client
|
||||||
|
|
||||||
logger.info("[_pick_idp] %s" % environ)
|
logger.info("[_pick_idp] %s" % environ)
|
||||||
if "HTTP_PAOS" in environ:
|
if "HTTP_PAOS" in environ:
|
||||||
if environ["HTTP_PAOS"] == PAOS_HEADER_INFO:
|
if environ["HTTP_PAOS"] == PAOS_HEADER_INFO:
|
||||||
@@ -200,14 +203,13 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
logger.info("- ECP client detected -")
|
logger.info("- ECP client detected -")
|
||||||
|
|
||||||
_relay_state = construct_came_from(environ)
|
_relay_state = construct_came_from(environ)
|
||||||
_entityid = self.saml_client.config.ecp_endpoint(
|
_entityid = _cli.config.ecp_endpoint(environ["REMOTE_ADDR"])
|
||||||
environ["REMOTE_ADDR"])
|
|
||||||
if not _entityid:
|
if not _entityid:
|
||||||
return -1, HTTPInternalServerError(
|
return -1, HTTPInternalServerError(
|
||||||
detail="No IdP to talk to"
|
detail="No IdP to talk to")
|
||||||
)
|
|
||||||
logger.info("IdP to talk to: %s" % _entityid)
|
logger.info("IdP to talk to: %s" % _entityid)
|
||||||
return ecp.ecp_auth_request(self.saml_client, _entityid,
|
return ecp.ecp_auth_request(_cli, _entityid,
|
||||||
_relay_state)
|
_relay_state)
|
||||||
else:
|
else:
|
||||||
return -1, HTTPInternalServerError(
|
return -1, HTTPInternalServerError(
|
||||||
@@ -216,12 +218,11 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
return -1, HTTPInternalServerError(
|
return -1, HTTPInternalServerError(
|
||||||
detail='unknown ECP version')
|
detail='unknown ECP version')
|
||||||
|
|
||||||
|
|
||||||
idps = self.metadata.with_descriptor("idpsso")
|
idps = self.metadata.with_descriptor("idpsso")
|
||||||
|
|
||||||
logger.info("IdP URL: %s" % idps)
|
logger.info("IdP URL: %s" % idps)
|
||||||
|
|
||||||
if len( idps ) == 1:
|
if len(idps) == 1:
|
||||||
# idps is a dictionary
|
# idps is a dictionary
|
||||||
idp_entity_id = idps.keys()[0]
|
idp_entity_id = idps.keys()[0]
|
||||||
elif not len(idps):
|
elif not len(idps):
|
||||||
@@ -229,16 +230,17 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
else:
|
else:
|
||||||
idp_entity_id = ""
|
idp_entity_id = ""
|
||||||
logger.info("ENVIRON: %s" % environ)
|
logger.info("ENVIRON: %s" % environ)
|
||||||
query = environ.get('s2repoze.body','')
|
query = environ.get('s2repoze.body', '')
|
||||||
if not query:
|
if not query:
|
||||||
query = environ.get("QUERY_STRING","")
|
query = environ.get("QUERY_STRING", "")
|
||||||
|
|
||||||
logger.info("<_pick_idp> query: %s" % query)
|
logger.info("<_pick_idp> query: %s" % query)
|
||||||
|
|
||||||
if self.wayf:
|
if self.wayf:
|
||||||
if query:
|
if query:
|
||||||
try:
|
try:
|
||||||
wayf_selected = dict(parse_qs(query))["wayf_selected"][0]
|
wayf_selected = dict(parse_qs(query))[
|
||||||
|
"wayf_selected"][0]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return self._wayf_redirect(came_from)
|
return self._wayf_redirect(came_from)
|
||||||
idp_entity_id = wayf_selected
|
idp_entity_id = wayf_selected
|
||||||
@@ -246,16 +248,16 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
return self._wayf_redirect(came_from)
|
return self._wayf_redirect(came_from)
|
||||||
elif self.discosrv:
|
elif self.discosrv:
|
||||||
if query:
|
if query:
|
||||||
idp_entity_id = self.saml_client.parse_discovery_service_response(
|
idp_entity_id = _cli.parse_discovery_service_response(
|
||||||
query=environ.get("QUERY_STRING"))
|
query=environ.get("QUERY_STRING"))
|
||||||
else:
|
else:
|
||||||
sid_ = sid()
|
sid_ = sid()
|
||||||
self.outstanding_queries[sid_] = came_from
|
self.outstanding_queries[sid_] = came_from
|
||||||
logger.info("Redirect to Discovery Service function")
|
logger.info("Redirect to Discovery Service function")
|
||||||
eid = self.saml_client.config.entity_id
|
eid = _cli.config.entity_id
|
||||||
loc = self.saml_client.create_discovery_service_request(
|
loc = _cli.create_discovery_service_request(self.discosrv,
|
||||||
self.discosrv, eid)
|
eid)
|
||||||
return -1, HTTPSeeOther(headers = [('Location',loc)])
|
return -1, HTTPSeeOther(headers=[('Location', loc)])
|
||||||
else:
|
else:
|
||||||
return -1, HTTPNotImplemented(detail='No WAYF or DJ present!')
|
return -1, HTTPNotImplemented(detail='No WAYF or DJ present!')
|
||||||
|
|
||||||
@@ -266,8 +268,10 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
#noinspection PyUnusedLocal
|
#noinspection PyUnusedLocal
|
||||||
def challenge(self, environ, _status, _app_headers, _forget_headers):
|
def challenge(self, environ, _status, _app_headers, _forget_headers):
|
||||||
|
|
||||||
# this challenge consist in login out
|
_cli = self.saml_client
|
||||||
if environ.has_key('rwpc.logout'):
|
|
||||||
|
# this challenge consist in logging out
|
||||||
|
if 'rwpc.logout' in environ:
|
||||||
# ignore right now?
|
# ignore right now?
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -283,7 +287,7 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
vorg_name = environ["myapp.vo"]
|
vorg_name = environ["myapp.vo"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
try:
|
try:
|
||||||
vorg_name = self.saml_client.vorg._name
|
vorg_name = _cli.vorg._name
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
vorg_name = ""
|
vorg_name = ""
|
||||||
|
|
||||||
@@ -304,7 +308,6 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
entity_id = response
|
entity_id = response
|
||||||
logger.info("[sp.challenge] entity_id: %s" % entity_id)
|
logger.info("[sp.challenge] entity_id: %s" % entity_id)
|
||||||
# Do the AuthnRequest
|
# Do the AuthnRequest
|
||||||
_cli = self.saml_client
|
|
||||||
_binding = BINDING_HTTP_REDIRECT
|
_binding = BINDING_HTTP_REDIRECT
|
||||||
try:
|
try:
|
||||||
srvs = _cli.metadata.single_sign_on_service(entity_id, _binding)
|
srvs = _cli.metadata.single_sign_on_service(entity_id, _binding)
|
||||||
@@ -319,7 +322,8 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
logger.debug("ht_args: %s" % ht_args)
|
logger.debug("ht_args: %s" % ht_args)
|
||||||
except Exception, exc:
|
except Exception, exc:
|
||||||
logger.exception(exc)
|
logger.exception(exc)
|
||||||
raise Exception("Failed to construct the AuthnRequest: %s" % exc)
|
raise Exception(
|
||||||
|
"Failed to construct the AuthnRequest: %s" % exc)
|
||||||
|
|
||||||
# remember the request
|
# remember the request
|
||||||
self.outstanding_queries[_sid] = came_from
|
self.outstanding_queries[_sid] = came_from
|
||||||
@@ -327,14 +331,15 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
if not ht_args["data"] and ht_args["headers"][0][0] == "Location":
|
if not ht_args["data"] and ht_args["headers"][0][0] == "Location":
|
||||||
logger.debug('redirect to: %s' % ht_args["headers"][0][1])
|
logger.debug('redirect to: %s' % ht_args["headers"][0][1])
|
||||||
return HTTPSeeOther(headers=ht_args["headers"])
|
return HTTPSeeOther(headers=ht_args["headers"])
|
||||||
else :
|
else:
|
||||||
return ht_args["data"]
|
return ht_args["data"]
|
||||||
|
|
||||||
def _construct_identity(self, session_info):
|
def _construct_identity(self, session_info):
|
||||||
|
cni = code(session_info["name_id"])
|
||||||
identity = {
|
identity = {
|
||||||
"login": session_info["name_id"],
|
"login": cni,
|
||||||
"password": "",
|
"password": "",
|
||||||
'repoze.who.userid': session_info["name_id"],
|
'repoze.who.userid': cni,
|
||||||
"user": session_info["ava"],
|
"user": session_info["ava"],
|
||||||
}
|
}
|
||||||
logger.debug("Identity: %s" % identity)
|
logger.debug("Identity: %s" % identity)
|
||||||
@@ -349,8 +354,7 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
# Evaluate the response, returns a AuthnResponse instance
|
# Evaluate the response, returns a AuthnResponse instance
|
||||||
try:
|
try:
|
||||||
authresp = self.saml_client.parse_authn_request_response(
|
authresp = self.saml_client.parse_authn_request_response(
|
||||||
post["SAMLResponse"],
|
post["SAMLResponse"], BINDING_HTTP_POST,
|
||||||
BINDING_HTTP_POST,
|
|
||||||
self.outstanding_queries)
|
self.outstanding_queries)
|
||||||
except Exception, excp:
|
except Exception, excp:
|
||||||
logger.exception("Exception: %s" % (excp,))
|
logger.exception("Exception: %s" % (excp,))
|
||||||
@@ -413,7 +417,7 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not post.has_key("SAMLResponse"):
|
if "SAMLResponse" not in post:
|
||||||
logger.info("[sp.identify] --- NOT SAMLResponse ---")
|
logger.info("[sp.identify] --- NOT SAMLResponse ---")
|
||||||
# Not for me, put the post back where next in line can
|
# Not for me, put the post back where next in line can
|
||||||
# find it
|
# find it
|
||||||
@@ -424,8 +428,8 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
# check for SAML2 authN response
|
# check for SAML2 authN response
|
||||||
#if self.debug:
|
#if self.debug:
|
||||||
try:
|
try:
|
||||||
session_info = self._eval_authn_response(environ,
|
session_info = self._eval_authn_response(
|
||||||
cgi_field_storage_to_dict(post))
|
environ, cgi_field_storage_to_dict(post))
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
except TypeError, exc:
|
except TypeError, exc:
|
||||||
@@ -445,37 +449,28 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
|
|
||||||
if session_info:
|
if session_info:
|
||||||
environ["s2repoze.sessioninfo"] = session_info
|
environ["s2repoze.sessioninfo"] = session_info
|
||||||
name_id = session_info["name_id"]
|
return self._construct_identity(session_info)
|
||||||
# contruct and return the identity
|
|
||||||
identity = {
|
|
||||||
"login": name_id,
|
|
||||||
"password": "",
|
|
||||||
'repoze.who.userid': name_id,
|
|
||||||
"user": self.saml_client.users.get_identity(name_id)[0],
|
|
||||||
}
|
|
||||||
logger.info("[sp.identify] IDENTITY: %s" % (identity,))
|
|
||||||
return identity
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# IMetadataProvider
|
# IMetadataProvider
|
||||||
def add_metadata(self, environ, identity):
|
def add_metadata(self, environ, identity):
|
||||||
""" Add information to the knowledge I have about the user """
|
""" Add information to the knowledge I have about the user """
|
||||||
subject_id = identity['repoze.who.userid']
|
name_id = identity['repoze.who.userid']
|
||||||
#logger = environ.get('repoze.who.logger','')
|
if isinstance(name_id, basestring):
|
||||||
|
name_id = decode(name_id)
|
||||||
|
|
||||||
_cli = self.saml_client
|
_cli = self.saml_client
|
||||||
logger.debug("[add_metadata] for %s" % subject_id)
|
logger.debug("[add_metadata] for %s" % name_id)
|
||||||
try:
|
try:
|
||||||
logger.debug("Issuers: %s" % _cli.users.sources(subject_id))
|
logger.debug("Issuers: %s" % _cli.users.sources(name_id))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if "user" not in identity:
|
if "user" not in identity:
|
||||||
identity["user"] = {}
|
identity["user"] = {}
|
||||||
try:
|
try:
|
||||||
(ava, _) = _cli.users.get_identity(subject_id)
|
(ava, _) = _cli.users.get_identity(name_id)
|
||||||
#now = time.gmtime()
|
#now = time.gmtime()
|
||||||
logger.debug("[add_metadata] adds: %s" % ava)
|
logger.debug("[add_metadata] adds: %s" % ava)
|
||||||
identity["user"].update(ava)
|
identity["user"].update(ava)
|
||||||
@@ -486,9 +481,9 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
# is this a Virtual Organization situation
|
# is this a Virtual Organization situation
|
||||||
for vo in _cli.vorg.values():
|
for vo in _cli.vorg.values():
|
||||||
try:
|
try:
|
||||||
if vo.do_aggregation(subject_id):
|
if vo.do_aggregation(name_id):
|
||||||
# Get the extended identity
|
# Get the extended identity
|
||||||
identity["user"] = _cli.users.get_identity(subject_id)[0]
|
identity["user"] = _cli.users.get_identity(name_id)[0]
|
||||||
# Only do this once, mark that the identity has been
|
# Only do this once, mark that the identity has been
|
||||||
# expanded
|
# expanded
|
||||||
identity["pysaml2_vo_expanded"] = 1
|
identity["pysaml2_vo_expanded"] = 1
|
||||||
@@ -505,7 +500,7 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
# used 2 times : one to get the ticket, the other to validate it
|
# used 2 times : one to get the ticket, the other to validate it
|
||||||
def _service_url(self, environ, qstr=None):
|
def _service_url(self, environ, qstr=None):
|
||||||
if qstr is not None:
|
if qstr is not None:
|
||||||
url = construct_url(environ, querystring = qstr)
|
url = construct_url(environ, querystring=qstr)
|
||||||
else:
|
else:
|
||||||
url = construct_url(environ)
|
url = construct_url(environ)
|
||||||
return url
|
return url
|
||||||
@@ -519,8 +514,8 @@ class SAML2Plugin(FormPluginBase):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def make_plugin(rememberer_name=None, # plugin for remember
|
def make_plugin(remember_name=None, # plugin for remember
|
||||||
cache= "", # cache
|
cache="", # cache
|
||||||
# Which virtual organization to support
|
# Which virtual organization to support
|
||||||
virtual_organization="",
|
virtual_organization="",
|
||||||
saml_conf="",
|
saml_conf="",
|
||||||
@@ -534,17 +529,14 @@ def make_plugin(rememberer_name=None, # plugin for remember
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
'must include saml_conf in configuration')
|
'must include saml_conf in configuration')
|
||||||
|
|
||||||
if rememberer_name is None:
|
if remember_name is None:
|
||||||
raise ValueError(
|
raise ValueError('must include remember_name in configuration')
|
||||||
'must include rememberer_name in configuration')
|
|
||||||
|
|
||||||
conf = config_factory("sp", saml_conf)
|
conf = config_factory("sp", saml_conf)
|
||||||
|
|
||||||
scl = Saml2Client(config=conf, identity_cache=identity_cache,
|
scl = Saml2Client(config=conf, identity_cache=identity_cache,
|
||||||
virtual_organization=virtual_organization)
|
virtual_organization=virtual_organization)
|
||||||
|
|
||||||
plugin = SAML2Plugin(rememberer_name, conf, scl, wayf, cache, sid_store,
|
plugin = SAML2Plugin(remember_name, conf, scl, wayf, cache, sid_store,
|
||||||
discovery)
|
discovery)
|
||||||
return plugin
|
return plugin
|
||||||
|
|
||||||
|
|
||||||
|
@@ -18,6 +18,7 @@
|
|||||||
"""Contains classes and functions that a SAML2.0 Service Provider (SP) may use
|
"""Contains classes and functions that a SAML2.0 Service Provider (SP) may use
|
||||||
to conclude its tasks.
|
to conclude its tasks.
|
||||||
"""
|
"""
|
||||||
|
from saml2.ident import decode
|
||||||
from saml2.httpbase import HTTPError
|
from saml2.httpbase import HTTPError
|
||||||
from saml2.s_utils import sid
|
from saml2.s_utils import sid
|
||||||
import saml2
|
import saml2
|
||||||
@@ -99,11 +100,14 @@ class Saml2Client(Base):
|
|||||||
conversation.
|
conversation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(name_id, basestring):
|
||||||
|
name_id = decode(name_id)
|
||||||
|
|
||||||
logger.info("logout request for: %s" % name_id)
|
logger.info("logout request for: %s" % name_id)
|
||||||
|
|
||||||
# find out which IdPs/AAs I should notify
|
# find out which IdPs/AAs I should notify
|
||||||
entity_ids = self.users.issuers_of_info(name_id)
|
entity_ids = self.users.issuers_of_info(name_id)
|
||||||
|
self.users.remove_person(name_id)
|
||||||
return self.do_logout(name_id, entity_ids, reason, expire, sign)
|
return self.do_logout(name_id, entity_ids, reason, expire, sign)
|
||||||
|
|
||||||
def do_logout(self, name_id, entity_ids, reason, expire, sign=None):
|
def do_logout(self, name_id, entity_ids, reason, expire, sign=None):
|
||||||
|
@@ -204,12 +204,12 @@ class Entity(HTTPBase):
|
|||||||
|
|
||||||
raise Exception("Unkown entity or unsupported bindings")
|
raise Exception("Unkown entity or unsupported bindings")
|
||||||
|
|
||||||
def message_args(self, id=0):
|
def message_args(self, seid=0):
|
||||||
if not id:
|
if not seid:
|
||||||
id = sid(self.seed)
|
seid = sid(self.seed)
|
||||||
|
|
||||||
return {"id":id, "version":VERSION,
|
return {"id": seid, "version": VERSION,
|
||||||
"issue_instant":instant(), "issuer":self._issuer()}
|
"issue_instant": instant(), "issuer": self._issuer()}
|
||||||
|
|
||||||
def response_args(self, message, bindings=None, descr_type=""):
|
def response_args(self, message, bindings=None, descr_type=""):
|
||||||
info = {"in_response_to": message.id}
|
info = {"in_response_to": message.id}
|
||||||
@@ -278,8 +278,7 @@ class Entity(HTTPBase):
|
|||||||
:param text: The SOAP message
|
:param text: The SOAP message
|
||||||
:return: A dictionary with two keys "body" and "header"
|
:return: A dictionary with two keys "body" and "header"
|
||||||
"""
|
"""
|
||||||
return class_instances_from_soap_enveloped_saml_thingies(text,
|
return class_instances_from_soap_enveloped_saml_thingies(text, [paos,
|
||||||
[paos,
|
|
||||||
ecp,
|
ecp,
|
||||||
samlp])
|
samlp])
|
||||||
|
|
||||||
@@ -367,8 +366,8 @@ class Entity(HTTPBase):
|
|||||||
_issuer = self._issuer(issuer)
|
_issuer = self._issuer(issuer)
|
||||||
|
|
||||||
response = response_factory(issuer=_issuer,
|
response = response_factory(issuer=_issuer,
|
||||||
in_response_to = in_response_to,
|
in_response_to=in_response_to,
|
||||||
status = status)
|
status=status)
|
||||||
|
|
||||||
if consumer_url:
|
if consumer_url:
|
||||||
response.destination = consumer_url
|
response.destination = consumer_url
|
||||||
@@ -377,7 +376,7 @@ class Entity(HTTPBase):
|
|||||||
setattr(response, key, val)
|
setattr(response, key, val)
|
||||||
|
|
||||||
if sign:
|
if sign:
|
||||||
self.sign(response,to_sign=to_sign)
|
self.sign(response, to_sign=to_sign)
|
||||||
elif to_sign:
|
elif to_sign:
|
||||||
return signed_instance_factory(response, self.sec, to_sign)
|
return signed_instance_factory(response, self.sec, to_sign)
|
||||||
else:
|
else:
|
||||||
@@ -689,8 +688,8 @@ class Entity(HTTPBase):
|
|||||||
if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]:
|
if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]:
|
||||||
try:
|
try:
|
||||||
# expected return address
|
# expected return address
|
||||||
kwargs["return_addr"] = self.config.endpoint(service,
|
kwargs["return_addr"] = self.config.endpoint(
|
||||||
binding=binding)[0]
|
service, binding=binding)[0]
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.info("Not supposed to handle this!")
|
logger.info("Not supposed to handle this!")
|
||||||
return None
|
return None
|
||||||
|
@@ -19,8 +19,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
__author__ = 'rolandh'
|
__author__ = 'rolandh'
|
||||||
|
|
||||||
ATTRS = {"version":None,
|
ATTRS = {"version": None,
|
||||||
"name":"",
|
"name": "",
|
||||||
"value": None,
|
"value": None,
|
||||||
"port": None,
|
"port": None,
|
||||||
"port_specified": False,
|
"port_specified": False,
|
||||||
@@ -43,12 +43,15 @@ PAIRS = {
|
|||||||
"path": "path_specified"
|
"path": "path_specified"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConnectionError(Exception):
|
class ConnectionError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HTTPError(Exception):
|
class HTTPError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _since_epoch(cdate):
|
def _since_epoch(cdate):
|
||||||
"""
|
"""
|
||||||
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
|
:param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT'
|
||||||
@@ -67,16 +70,19 @@ def _since_epoch(cdate):
|
|||||||
#return int(time.mktime(t))
|
#return int(time.mktime(t))
|
||||||
return calendar.timegm(t)
|
return calendar.timegm(t)
|
||||||
|
|
||||||
|
|
||||||
def set_list2dict(sl):
|
def set_list2dict(sl):
|
||||||
return dict(sl)
|
return dict(sl)
|
||||||
|
|
||||||
|
|
||||||
def dict2set_list(dic):
|
def dict2set_list(dic):
|
||||||
return [(k,v) for k,v in dic.items()]
|
return [(k, v) for k, v in dic.items()]
|
||||||
|
|
||||||
|
|
||||||
class HTTPBase(object):
|
class HTTPBase(object):
|
||||||
def __init__(self, verify=True, ca_bundle=None, key_file=None,
|
def __init__(self, verify=True, ca_bundle=None, key_file=None,
|
||||||
cert_file=None):
|
cert_file=None):
|
||||||
self.request_args = {"allow_redirects": False,}
|
self.request_args = {"allow_redirects": False}
|
||||||
#self.cookies = {}
|
#self.cookies = {}
|
||||||
self.cookiejar = cookielib.CookieJar()
|
self.cookiejar = cookielib.CookieJar()
|
||||||
|
|
||||||
@@ -98,6 +104,12 @@ class HTTPBase(object):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
part = urlparse.urlparse(url)
|
part = urlparse.urlparse(url)
|
||||||
|
|
||||||
|
#if part.port:
|
||||||
|
# _domain = "%s:%s" % (part.hostname, part.port)
|
||||||
|
#else:
|
||||||
|
_domain = part.hostname
|
||||||
|
|
||||||
cookie_dict = {}
|
cookie_dict = {}
|
||||||
now = utc_now()
|
now = utc_now()
|
||||||
for _, a in list(self.cookiejar._cookies.items()):
|
for _, a in list(self.cookiejar._cookies.items()):
|
||||||
@@ -106,6 +118,8 @@ class HTTPBase(object):
|
|||||||
# print cookie
|
# print cookie
|
||||||
if cookie.expires and cookie.expires <= now:
|
if cookie.expires and cookie.expires <= now:
|
||||||
continue
|
continue
|
||||||
|
if not re.search("%s$" % cookie.domain, _domain):
|
||||||
|
continue
|
||||||
if not re.match(cookie.path, part.path):
|
if not re.match(cookie.path, part.path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -116,8 +130,13 @@ class HTTPBase(object):
|
|||||||
def set_cookie(self, kaka, request):
|
def set_cookie(self, kaka, request):
|
||||||
"""Returns a cookielib.Cookie based on a set-cookie header line"""
|
"""Returns a cookielib.Cookie based on a set-cookie header line"""
|
||||||
|
|
||||||
# default rfc2109=False
|
if not kaka:
|
||||||
# max-age, httponly
|
return
|
||||||
|
|
||||||
|
part = urlparse.urlparse(request.url)
|
||||||
|
_domain = part.hostname
|
||||||
|
logger.debug("%s: '%s'" % (_domain, kaka))
|
||||||
|
|
||||||
for cookie_name, morsel in kaka.items():
|
for cookie_name, morsel in kaka.items():
|
||||||
std_attr = ATTRS.copy()
|
std_attr = ATTRS.copy()
|
||||||
std_attr["name"] = cookie_name
|
std_attr["name"] = cookie_name
|
||||||
@@ -133,9 +152,9 @@ class HTTPBase(object):
|
|||||||
if attr in ATTRS:
|
if attr in ATTRS:
|
||||||
if morsel[attr]:
|
if morsel[attr]:
|
||||||
if attr == "expires":
|
if attr == "expires":
|
||||||
std_attr[attr]=_since_epoch(morsel[attr])
|
std_attr[attr] = _since_epoch(morsel[attr])
|
||||||
else:
|
else:
|
||||||
std_attr[attr]=morsel[attr]
|
std_attr[attr] = morsel[attr]
|
||||||
elif attr == "max-age":
|
elif attr == "max-age":
|
||||||
if morsel["max-age"]:
|
if morsel["max-age"]:
|
||||||
std_attr["expires"] = _since_epoch(morsel["max-age"])
|
std_attr["expires"] = _since_epoch(morsel["max-age"])
|
||||||
@@ -144,8 +163,12 @@ class HTTPBase(object):
|
|||||||
if std_attr[att]:
|
if std_attr[att]:
|
||||||
std_attr[item] = True
|
std_attr[item] = True
|
||||||
|
|
||||||
if std_attr["domain"] and std_attr["domain"].startswith("."):
|
if std_attr["domain"]:
|
||||||
|
if std_attr["domain"].startswith("."):
|
||||||
std_attr["domain_initial_dot"] = True
|
std_attr["domain_initial_dot"] = True
|
||||||
|
else:
|
||||||
|
std_attr["domain"] = _domain
|
||||||
|
std_attr["domain_specified"] = True
|
||||||
|
|
||||||
if morsel["max-age"] is 0:
|
if morsel["max-age"] is 0:
|
||||||
try:
|
try:
|
||||||
@@ -154,6 +177,13 @@ class HTTPBase(object):
|
|||||||
name=std_attr["name"])
|
name=std_attr["name"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
elif morsel["expires"] < utc_now():
|
||||||
|
try:
|
||||||
|
self.cookiejar.clear(domain=std_attr["domain"],
|
||||||
|
path=std_attr["path"],
|
||||||
|
name=std_attr["name"])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
new_cookie = cookielib.Cookie(**std_attr)
|
new_cookie = cookielib.Cookie(**std_attr)
|
||||||
self.cookiejar.set_cookie(new_cookie)
|
self.cookiejar.set_cookie(new_cookie)
|
||||||
@@ -164,21 +194,22 @@ class HTTPBase(object):
|
|||||||
_kwargs.update(kwargs)
|
_kwargs.update(kwargs)
|
||||||
|
|
||||||
if self.cookiejar:
|
if self.cookiejar:
|
||||||
_kwargs["cookies"] = self.cookies(url)
|
_cd = self.cookies(url)
|
||||||
|
if _cd:
|
||||||
|
_kwargs["cookies"] = _cd
|
||||||
|
logger.debug("Sent cookies: %s" % _kwargs["cookies"])
|
||||||
|
|
||||||
if self.user and self.passwd:
|
if self.user and self.passwd:
|
||||||
_kwargs["auth"]= (self.user, self.passwd)
|
_kwargs["auth"] = (self.user, self.passwd)
|
||||||
|
|
||||||
#logger.info("SENT COOKIEs: %s" % (_kwargs["cookies"],))
|
|
||||||
try:
|
try:
|
||||||
r = requests.request(method, url, **_kwargs)
|
r = requests.request(method, url, **_kwargs)
|
||||||
except requests.ConnectionError, exc:
|
except requests.ConnectionError, exc:
|
||||||
raise ConnectionError("%s" % exc)
|
raise ConnectionError("%s" % exc)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
#logger.info("RECEIVED COOKIEs: %s" % (r.headers["set-cookie"],))
|
|
||||||
self.set_cookie(SimpleCookie(r.headers["set-cookie"]), r)
|
self.set_cookie(SimpleCookie(r.headers["set-cookie"]), r)
|
||||||
except AttributeError, err:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return r
|
return r
|
||||||
@@ -196,7 +227,7 @@ class HTTPBase(object):
|
|||||||
:return: dictionary
|
:return: dictionary
|
||||||
"""
|
"""
|
||||||
if not isinstance(message, basestring):
|
if not isinstance(message, basestring):
|
||||||
request = "%s" % (message,)
|
message = "%s" % (message,)
|
||||||
|
|
||||||
return http_form_post_message(message, destination, relay_state, typ)
|
return http_form_post_message(message, destination, relay_state, typ)
|
||||||
|
|
||||||
@@ -213,7 +244,7 @@ class HTTPBase(object):
|
|||||||
:return: dictionary
|
:return: dictionary
|
||||||
"""
|
"""
|
||||||
if not isinstance(message, basestring):
|
if not isinstance(message, basestring):
|
||||||
request = "%s" % (message,)
|
message = "%s" % (message,)
|
||||||
|
|
||||||
return http_redirect_message(message, destination, relay_state, typ)
|
return http_redirect_message(message, destination, relay_state, typ)
|
||||||
|
|
||||||
@@ -278,7 +309,7 @@ class HTTPBase(object):
|
|||||||
soap_message = _signed
|
soap_message = _signed
|
||||||
|
|
||||||
return {"url": destination, "method": "POST",
|
return {"url": destination, "method": "POST",
|
||||||
"data":soap_message, "headers":headers}
|
"data": soap_message, "headers": headers}
|
||||||
|
|
||||||
def send_using_soap(self, request, destination, headers=None, sign=False):
|
def send_using_soap(self, request, destination, headers=None, sign=False):
|
||||||
"""
|
"""
|
||||||
@@ -304,7 +335,7 @@ class HTTPBase(object):
|
|||||||
logger.info("SOAP response: %s" % response.text)
|
logger.info("SOAP response: %s" % response.text)
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
raise HTTPError("%d:%s" % (response.status_code, response.error))
|
raise HTTPError("%d:%s" % (response.status_code, response.content))
|
||||||
|
|
||||||
def add_credentials(self, user, passwd):
|
def add_credentials(self, user, passwd):
|
||||||
self.user = user
|
self.user = user
|
||||||
|
@@ -202,6 +202,7 @@ class MetaData(object):
|
|||||||
res[srv["binding"]].append(srv)
|
res[srv["binding"]].append(srv)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
res[srv["binding"]] = [srv]
|
res[srv["binding"]] = [srv]
|
||||||
|
logger.debug("_service => %s" % res)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def _ext_service(self, entity_id, typ, service, binding):
|
def _ext_service(self, entity_id, typ, service, binding):
|
||||||
|
@@ -28,7 +28,13 @@ import saml2
|
|||||||
import base64
|
import base64
|
||||||
import urllib
|
import urllib
|
||||||
from saml2.s_utils import deflate_and_base64_encode
|
from saml2.s_utils import deflate_and_base64_encode
|
||||||
|
from saml2.s_utils import Unsupported
|
||||||
import logging
|
import logging
|
||||||
|
from saml2.sigver import RSA_SHA1
|
||||||
|
from saml2.sigver import REQ_ORDER
|
||||||
|
from saml2.sigver import RESP_ORDER
|
||||||
|
from saml2.sigver import RSASigner
|
||||||
|
from saml2.sigver import sha1_digest
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,7 +57,9 @@ FORM_SPEC = """<form method="post" action="%s">
|
|||||||
<input type="hidden" name="RelayState" value="%s" />
|
<input type="hidden" name="RelayState" value="%s" />
|
||||||
</form>"""
|
</form>"""
|
||||||
|
|
||||||
def http_form_post_message(message, location, relay_state="", typ="SAMLRequest"):
|
|
||||||
|
def http_form_post_message(message, location, relay_state="",
|
||||||
|
typ="SAMLRequest"):
|
||||||
"""The HTTP POST binding defines a mechanism by which SAML protocol
|
"""The HTTP POST binding defines a mechanism by which SAML protocol
|
||||||
messages may be transmitted within the base64-encoded content of a
|
messages may be transmitted within the base64-encoded content of a
|
||||||
HTML form control.
|
HTML form control.
|
||||||
@@ -81,19 +89,9 @@ def http_form_post_message(message, location, relay_state="", typ="SAMLRequest")
|
|||||||
|
|
||||||
return {"headers": [("Content-type", "text/html")], "data": response}
|
return {"headers": [("Content-type", "text/html")], "data": response}
|
||||||
|
|
||||||
##noinspection PyUnresolvedReferences
|
|
||||||
#def http_post_message(message, location, relay_state="", typ="SAMLRequest"):
|
|
||||||
# """
|
|
||||||
#
|
|
||||||
# :param message:
|
|
||||||
# :param location:
|
|
||||||
# :param relay_state:
|
|
||||||
# :param typ:
|
|
||||||
# :return:
|
|
||||||
# """
|
|
||||||
# return {"headers": [("Content-type", "text/xml")], "data": message}
|
|
||||||
|
|
||||||
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
def http_redirect_message(message, location, relay_state="", typ="SAMLRequest",
|
||||||
|
sigalg=None, key=None):
|
||||||
"""The HTTP Redirect binding defines a mechanism by which SAML protocol
|
"""The HTTP Redirect binding defines a mechanism by which SAML protocol
|
||||||
messages can be transmitted within URL parameters.
|
messages can be transmitted within URL parameters.
|
||||||
Messages are encoded for use with this binding using a URL encoding
|
Messages are encoded for use with this binding using a URL encoding
|
||||||
@@ -104,13 +102,21 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
|||||||
:param message: The message
|
:param message: The message
|
||||||
:param location: Where the message should be posted to
|
:param location: Where the message should be posted to
|
||||||
:param relay_state: for preserving and conveying state information
|
:param relay_state: for preserving and conveying state information
|
||||||
|
:param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart
|
||||||
|
:param sigalg: The signature algorithm to use.
|
||||||
|
:param key: Key to use for signing
|
||||||
:return: A tuple containing header information and a HTML message.
|
:return: A tuple containing header information and a HTML message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(message, basestring):
|
if not isinstance(message, basestring):
|
||||||
message = "%s" % (message,)
|
message = "%s" % (message,)
|
||||||
|
|
||||||
|
_order = None
|
||||||
if typ in ["SAMLRequest", "SAMLResponse"]:
|
if typ in ["SAMLRequest", "SAMLResponse"]:
|
||||||
|
if typ == "SAMLRequest":
|
||||||
|
_order = REQ_ORDER
|
||||||
|
else:
|
||||||
|
_order = RESP_ORDER
|
||||||
args = {typ: deflate_and_base64_encode(message)}
|
args = {typ: deflate_and_base64_encode(message)}
|
||||||
elif typ == "SAMLart":
|
elif typ == "SAMLart":
|
||||||
args = {typ: message}
|
args = {typ: message}
|
||||||
@@ -120,16 +126,35 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"):
|
|||||||
if relay_state:
|
if relay_state:
|
||||||
args["RelayState"] = relay_state
|
args["RelayState"] = relay_state
|
||||||
|
|
||||||
|
if sigalg:
|
||||||
|
# sigalgs
|
||||||
|
# http://www.w3.org/2000/09/xmldsig#dsa-sha1
|
||||||
|
# http://www.w3.org/2000/09/xmldsig#rsa-sha1
|
||||||
|
|
||||||
|
args["SigAlg"] = sigalg
|
||||||
|
|
||||||
|
if sigalg == RSA_SHA1:
|
||||||
|
signer = RSASigner(sha1_digest, "sha1")
|
||||||
|
string = "&".join([urllib.urlencode({k: args[k]}) for k in _order])
|
||||||
|
args["Signature"] = base64.b64encode(signer.sign(string, key))
|
||||||
|
string = urllib.urlencode(args)
|
||||||
|
else:
|
||||||
|
raise Unsupported("Signing algorithm")
|
||||||
|
else:
|
||||||
|
string = urllib.urlencode(args)
|
||||||
|
|
||||||
glue_char = "&" if urlparse.urlparse(location).query else "?"
|
glue_char = "&" if urlparse.urlparse(location).query else "?"
|
||||||
login_url = glue_char.join([location, urllib.urlencode(args)])
|
login_url = glue_char.join([location, string])
|
||||||
headers = [('Location', login_url)]
|
headers = [('Location', login_url)]
|
||||||
body = []
|
body = []
|
||||||
|
|
||||||
return {"headers":headers, "data":body}
|
return {"headers": headers, "data": body}
|
||||||
|
|
||||||
|
|
||||||
DUMMY_NAMESPACE = "http://example.org/"
|
DUMMY_NAMESPACE = "http://example.org/"
|
||||||
PREFIX = '<?xml version="1.0" encoding="UTF-8"?>'
|
PREFIX = '<?xml version="1.0" encoding="UTF-8"?>'
|
||||||
|
|
||||||
|
|
||||||
def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
|
def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
|
||||||
""" Returns a soap envelope containing a SAML request
|
""" Returns a soap envelope containing a SAML request
|
||||||
as a text string.
|
as a text string.
|
||||||
@@ -170,21 +195,24 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
|
|||||||
cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1]
|
cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1]
|
||||||
_str = _str.replace(cut1, "")
|
_str = _str.replace(cut1, "")
|
||||||
first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],))
|
first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],))
|
||||||
last = _str.find(">", first+14)
|
last = _str.find(">", first + 14)
|
||||||
cut2 = _str[first:last+1]
|
cut2 = _str[first:last + 1]
|
||||||
return _str.replace(cut2, thingy)
|
return _str.replace(cut2, thingy)
|
||||||
else:
|
else:
|
||||||
thingy.become_child_element_of(body)
|
thingy.become_child_element_of(body)
|
||||||
return ElementTree.tostring(envelope, encoding="UTF-8")
|
return ElementTree.tostring(envelope, encoding="UTF-8")
|
||||||
|
|
||||||
|
|
||||||
def http_soap_message(message):
|
def http_soap_message(message):
|
||||||
return {"headers": [("Content-type", "application/soap+xml")],
|
return {"headers": [("Content-type", "application/soap+xml")],
|
||||||
"data": make_soap_enveloped_saml_thingy(message)}
|
"data": make_soap_enveloped_saml_thingy(message)}
|
||||||
|
|
||||||
|
|
||||||
def http_paos(message, extra=None):
|
def http_paos(message, extra=None):
|
||||||
return {"headers":[("Content-type", "application/soap+xml")],
|
return {"headers": [("Content-type", "application/soap+xml")],
|
||||||
"data": make_soap_enveloped_saml_thingy(message, extra)}
|
"data": make_soap_enveloped_saml_thingy(message, extra)}
|
||||||
|
|
||||||
|
|
||||||
def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
||||||
"""Parses a SOAP enveloped SAML thing and returns header parts and body
|
"""Parses a SOAP enveloped SAML thing and returns header parts and body
|
||||||
|
|
||||||
@@ -226,13 +254,15 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
|
|||||||
PACKING = {
|
PACKING = {
|
||||||
saml2.BINDING_HTTP_REDIRECT: http_redirect_message,
|
saml2.BINDING_HTTP_REDIRECT: http_redirect_message,
|
||||||
saml2.BINDING_HTTP_POST: http_form_post_message,
|
saml2.BINDING_HTTP_POST: http_form_post_message,
|
||||||
}
|
}
|
||||||
|
|
||||||
def packager( identifier ):
|
|
||||||
|
def packager(identifier):
|
||||||
try:
|
try:
|
||||||
return PACKING[identifier]
|
return PACKING[identifier]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise Exception("Unkown binding type: %s" % identifier)
|
raise Exception("Unkown binding type: %s" % identifier)
|
||||||
|
|
||||||
|
|
||||||
def factory(binding, message, location, relay_state="", typ="SAMLRequest"):
|
def factory(binding, message, location, relay_state="", typ="SAMLRequest"):
|
||||||
return PACKING[binding](message, location, relay_state, typ)
|
return PACKING[binding](message, location, relay_state, typ)
|
@@ -9,7 +9,7 @@ import sys
|
|||||||
import hmac
|
import hmac
|
||||||
|
|
||||||
# from python 2.5
|
# from python 2.5
|
||||||
if sys.version_info >= (2,5):
|
if sys.version_info >= (2, 5):
|
||||||
import hashlib
|
import hashlib
|
||||||
else: # before python 2.5
|
else: # before python 2.5
|
||||||
import sha
|
import sha
|
||||||
@@ -27,36 +27,51 @@ import zlib
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SamlException(Exception):
|
class SamlException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RequestVersionTooLow(SamlException):
|
class RequestVersionTooLow(SamlException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RequestVersionTooHigh(SamlException):
|
class RequestVersionTooHigh(SamlException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UnknownPrincipal(SamlException):
|
class UnknownPrincipal(SamlException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class UnsupportedBinding(SamlException):
|
|
||||||
|
class Unsupported(SamlException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedBinding(Unsupported):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class VersionMismatch(Exception):
|
class VersionMismatch(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Unknown(Exception):
|
class Unknown(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OtherError(Exception):
|
class OtherError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MissingValue(Exception):
|
class MissingValue(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PolicyError(Exception):
|
class PolicyError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BadRequest(Exception):
|
class BadRequest(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -73,11 +88,12 @@ EXCEPTION2STATUS = {
|
|||||||
Exception: samlp.STATUS_AUTHN_FAILED,
|
Exception: samlp.STATUS_AUTHN_FAILED,
|
||||||
}
|
}
|
||||||
|
|
||||||
GENERIC_DOMAINS = "aero", "asia", "biz", "cat", "com", "coop", \
|
GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu",
|
||||||
"edu", "gov", "info", "int", "jobs", "mil", "mobi", "museum", \
|
"gov", "info", "int", "jobs", "mil", "mobi", "museum",
|
||||||
"name", "net", "org", "pro", "tel", "travel"
|
"name", "net", "org", "pro", "tel", "travel"]
|
||||||
|
|
||||||
def valid_email(emailaddress, domains = GENERIC_DOMAINS):
|
|
||||||
|
def valid_email(emailaddress, domains=GENERIC_DOMAINS):
|
||||||
"""Checks for a syntactically valid email address."""
|
"""Checks for a syntactically valid email address."""
|
||||||
|
|
||||||
# Email address must be at least 6 characters in total.
|
# Email address must be at least 6 characters in total.
|
||||||
@@ -106,23 +122,26 @@ def valid_email(emailaddress, domains = GENERIC_DOMAINS):
|
|||||||
else:
|
else:
|
||||||
return False # Email address has funny characters.
|
return False # Email address has funny characters.
|
||||||
|
|
||||||
def decode_base64_and_inflate( string ):
|
|
||||||
|
def decode_base64_and_inflate(string):
|
||||||
""" base64 decodes and then inflates according to RFC1951
|
""" base64 decodes and then inflates according to RFC1951
|
||||||
|
|
||||||
:param string: a deflated and encoded string
|
:param string: a deflated and encoded string
|
||||||
:return: the string after decoding and inflating
|
:return: the string after decoding and inflating
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return zlib.decompress( base64.b64decode( string ) , -15)
|
return zlib.decompress(base64.b64decode(string), -15)
|
||||||
|
|
||||||
def deflate_and_base64_encode( string_val ):
|
|
||||||
|
def deflate_and_base64_encode(string_val):
|
||||||
"""
|
"""
|
||||||
Deflates and the base64 encodes a string
|
Deflates and the base64 encodes a string
|
||||||
|
|
||||||
:param string_val: The string to deflate and encode
|
:param string_val: The string to deflate and encode
|
||||||
:return: The deflated and encoded string
|
:return: The deflated and encoded string
|
||||||
"""
|
"""
|
||||||
return base64.b64encode( zlib.compress( string_val )[2:-4] )
|
return base64.b64encode(zlib.compress(string_val)[2:-4])
|
||||||
|
|
||||||
|
|
||||||
def rndstr(size=16):
|
def rndstr(size=16):
|
||||||
"""
|
"""
|
||||||
@@ -134,9 +153,11 @@ def rndstr(size=16):
|
|||||||
_basech = string.ascii_letters + string.digits
|
_basech = string.ascii_letters + string.digits
|
||||||
return "".join([random.choice(_basech) for _ in range(size)])
|
return "".join([random.choice(_basech) for _ in range(size)])
|
||||||
|
|
||||||
|
|
||||||
def sid(seed=""):
|
def sid(seed=""):
|
||||||
"""The hash of the server time + seed makes an unique SID for each session.
|
"""The hash of the server time + seed makes an unique SID for each session.
|
||||||
128-bits long so it fulfills the SAML2 requirements which states 128-160 bits
|
128-bits long so it fulfills the SAML2 requirements which states
|
||||||
|
128-160 bits
|
||||||
|
|
||||||
:param seed: A seed string
|
:param seed: A seed string
|
||||||
:return: The hex version of the digest, prefixed by 'id-' to make it
|
:return: The hex version of the digest, prefixed by 'id-' to make it
|
||||||
@@ -146,7 +167,8 @@ def sid(seed=""):
|
|||||||
ident.update(repr(time.time()))
|
ident.update(repr(time.time()))
|
||||||
if seed:
|
if seed:
|
||||||
ident.update(seed)
|
ident.update(seed)
|
||||||
return "id-"+ident.hexdigest()
|
return "id-" + ident.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def parse_attribute_map(filenames):
|
def parse_attribute_map(filenames):
|
||||||
"""
|
"""
|
||||||
@@ -168,6 +190,7 @@ def parse_attribute_map(filenames):
|
|||||||
|
|
||||||
return forward, backward
|
return forward, backward
|
||||||
|
|
||||||
|
|
||||||
def identity_attribute(form, attribute, forward_map=None):
|
def identity_attribute(form, attribute, forward_map=None):
|
||||||
if form == "friendly":
|
if form == "friendly":
|
||||||
if attribute.friendly_name:
|
if attribute.friendly_name:
|
||||||
@@ -182,6 +205,7 @@ def identity_attribute(form, attribute, forward_map=None):
|
|||||||
|
|
||||||
#----------------------------------------------------------------------------
|
#----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def error_status_factory(info):
|
def error_status_factory(info):
|
||||||
if isinstance(info, Exception):
|
if isinstance(info, Exception):
|
||||||
try:
|
try:
|
||||||
@@ -194,32 +218,30 @@ def error_status_factory(info):
|
|||||||
status_code=samlp.StatusCode(
|
status_code=samlp.StatusCode(
|
||||||
value=samlp.STATUS_RESPONDER,
|
value=samlp.STATUS_RESPONDER,
|
||||||
status_code=samlp.StatusCode(
|
status_code=samlp.StatusCode(
|
||||||
value=exc_val)
|
value=exc_val)))
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
(errcode, text) = info
|
(errcode, text) = info
|
||||||
status = samlp.Status(
|
status = samlp.Status(
|
||||||
status_message=samlp.StatusMessage(text=text),
|
status_message=samlp.StatusMessage(text=text),
|
||||||
status_code=samlp.StatusCode(
|
status_code=samlp.StatusCode(
|
||||||
value=samlp.STATUS_RESPONDER,
|
value=samlp.STATUS_RESPONDER,
|
||||||
status_code=samlp.StatusCode(value=errcode)
|
status_code=samlp.StatusCode(value=errcode)))
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
||||||
def success_status_factory():
|
def success_status_factory():
|
||||||
return samlp.Status(status_code=samlp.StatusCode(
|
return samlp.Status(status_code=samlp.StatusCode(
|
||||||
value=samlp.STATUS_SUCCESS))
|
value=samlp.STATUS_SUCCESS))
|
||||||
|
|
||||||
|
|
||||||
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
|
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
|
||||||
return samlp.Status(
|
return samlp.Status(
|
||||||
status_message=samlp.StatusMessage(text=message),
|
status_message=samlp.StatusMessage(text=message),
|
||||||
status_code=samlp.StatusCode(
|
status_code=samlp.StatusCode(value=fro,
|
||||||
value=fro,
|
|
||||||
status_code=samlp.StatusCode(value=code)))
|
status_code=samlp.StatusCode(value=code)))
|
||||||
|
|
||||||
|
|
||||||
def assertion_factory(**kwargs):
|
def assertion_factory(**kwargs):
|
||||||
assertion = saml.Assertion(version=VERSION, id=sid(),
|
assertion = saml.Assertion(version=VERSION, id=sid(),
|
||||||
issue_instant=instant())
|
issue_instant=instant())
|
||||||
@@ -227,6 +249,7 @@ def assertion_factory(**kwargs):
|
|||||||
setattr(assertion, key, val)
|
setattr(assertion, key, val)
|
||||||
return assertion
|
return assertion
|
||||||
|
|
||||||
|
|
||||||
def _attrval(val, typ=""):
|
def _attrval(val, typ=""):
|
||||||
if isinstance(val, list) or isinstance(val, set):
|
if isinstance(val, list) or isinstance(val, set):
|
||||||
attrval = [saml.AttributeValue(text=v) for v in val]
|
attrval = [saml.AttributeValue(text=v) for v in val]
|
||||||
@@ -246,6 +269,7 @@ def _attrval(val, typ=""):
|
|||||||
# xmlns:xs="http://www.w3.org/2001/XMLSchema"
|
# xmlns:xs="http://www.w3.org/2001/XMLSchema"
|
||||||
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
|
||||||
|
|
||||||
def do_ava(val, typ=""):
|
def do_ava(val, typ=""):
|
||||||
if isinstance(val, basestring):
|
if isinstance(val, basestring):
|
||||||
ava = saml.AttributeValue()
|
ava = saml.AttributeValue()
|
||||||
@@ -253,7 +277,7 @@ def do_ava(val, typ=""):
|
|||||||
attrval = [ava]
|
attrval = [ava]
|
||||||
elif isinstance(val, list):
|
elif isinstance(val, list):
|
||||||
attrval = [do_ava(v)[0] for v in val]
|
attrval = [do_ava(v)[0] for v in val]
|
||||||
elif val or val == False:
|
elif val or val is False:
|
||||||
ava = saml.AttributeValue()
|
ava = saml.AttributeValue()
|
||||||
ava.set_text(val)
|
ava.set_text(val)
|
||||||
attrval = [ava]
|
attrval = [ava]
|
||||||
@@ -268,6 +292,7 @@ def do_ava(val, typ=""):
|
|||||||
|
|
||||||
return attrval
|
return attrval
|
||||||
|
|
||||||
|
|
||||||
def do_attribute(val, typ, key):
|
def do_attribute(val, typ, key):
|
||||||
attr = saml.Attribute()
|
attr = saml.Attribute()
|
||||||
attrval = do_ava(val, typ)
|
attrval = do_ava(val, typ)
|
||||||
@@ -290,6 +315,7 @@ def do_attribute(val, typ, key):
|
|||||||
attr.friendly_name = friendly
|
attr.friendly_name = friendly
|
||||||
return attr
|
return attr
|
||||||
|
|
||||||
|
|
||||||
def do_attributes(identity):
|
def do_attributes(identity):
|
||||||
attrs = []
|
attrs = []
|
||||||
if not identity:
|
if not identity:
|
||||||
@@ -308,6 +334,7 @@ def do_attributes(identity):
|
|||||||
attrs.append(attr)
|
attrs.append(attr)
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
def do_attribute_statement(identity):
|
def do_attribute_statement(identity):
|
||||||
"""
|
"""
|
||||||
:param identity: A dictionary with fiendly names as keys
|
:param identity: A dictionary with fiendly names as keys
|
||||||
@@ -315,12 +342,14 @@ def do_attribute_statement(identity):
|
|||||||
"""
|
"""
|
||||||
return saml.AttributeStatement(attribute=do_attributes(identity))
|
return saml.AttributeStatement(attribute=do_attributes(identity))
|
||||||
|
|
||||||
|
|
||||||
def factory(klass, **kwargs):
|
def factory(klass, **kwargs):
|
||||||
instance = klass()
|
instance = klass()
|
||||||
for key, val in kwargs.items():
|
for key, val in kwargs.items():
|
||||||
setattr(instance, key, val)
|
setattr(instance, key, val)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
def signature(secret, parts):
|
def signature(secret, parts):
|
||||||
"""Generates a signature.
|
"""Generates a signature.
|
||||||
"""
|
"""
|
||||||
@@ -334,6 +363,7 @@ def signature(secret, parts):
|
|||||||
|
|
||||||
return csum.hexdigest()
|
return csum.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def verify_signature(secret, parts):
|
def verify_signature(secret, parts):
|
||||||
""" Checks that the signature is correct """
|
""" Checks that the signature is correct """
|
||||||
if signature(secret, parts[:-1]) == parts[-1]:
|
if signature(secret, parts[:-1]) == parts[-1]:
|
||||||
@@ -344,9 +374,10 @@ def verify_signature(secret, parts):
|
|||||||
|
|
||||||
FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#"
|
FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#"
|
||||||
|
|
||||||
|
|
||||||
def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
|
def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
|
||||||
"""
|
"""
|
||||||
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value ) '#'
|
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value) '#'
|
||||||
Allowed attributes:
|
Allowed attributes:
|
||||||
TS the login time stamp
|
TS the login time stamp
|
||||||
RP the relying party entityID
|
RP the relying party entityID
|
||||||
|
@@ -49,7 +49,7 @@ from saml2.assertion import Policy
|
|||||||
from saml2.assertion import restriction_from_attribute_spec
|
from saml2.assertion import restriction_from_attribute_spec
|
||||||
from saml2.assertion import filter_attribute_value_assertions
|
from saml2.assertion import filter_attribute_value_assertions
|
||||||
|
|
||||||
from saml2.ident import IdentDB
|
from saml2.ident import IdentDB, code
|
||||||
#from saml2.profile import paos
|
#from saml2.profile import paos
|
||||||
from saml2.profile import ecp
|
from saml2.profile import ecp
|
||||||
|
|
||||||
@@ -70,6 +70,8 @@ class Server(Entity):
|
|||||||
self.ticket = {}
|
self.ticket = {}
|
||||||
self.authn = {}
|
self.authn = {}
|
||||||
self.assertion = {}
|
self.assertion = {}
|
||||||
|
self.user2uid = {}
|
||||||
|
self.uid2user = {}
|
||||||
|
|
||||||
def init_config(self, stype="idp"):
|
def init_config(self, stype="idp"):
|
||||||
""" Remaining init of the server configuration
|
""" Remaining init of the server configuration
|
||||||
@@ -194,8 +196,8 @@ class Server(Entity):
|
|||||||
def store_assertion(self, assertion, to_sign):
|
def store_assertion(self, assertion, to_sign):
|
||||||
self.assertion[assertion.id] = (assertion, to_sign)
|
self.assertion[assertion.id] = (assertion, to_sign)
|
||||||
|
|
||||||
def get_assertion(self, id):
|
def get_assertion(self, cid):
|
||||||
return self.assertion[id]
|
return self.assertion[cid]
|
||||||
|
|
||||||
def store_authn_statement(self, authn_statement, name_id):
|
def store_authn_statement(self, authn_statement, name_id):
|
||||||
"""
|
"""
|
||||||
@@ -204,7 +206,8 @@ class Server(Entity):
|
|||||||
:param name_id:
|
:param name_id:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
nkey = sha1("%s" % name_id).hexdigest()
|
logger.debug("store authn about: %s" % name_id)
|
||||||
|
nkey = sha1(code(name_id)).hexdigest()
|
||||||
logger.debug("Store authn_statement under key: %s" % nkey)
|
logger.debug("Store authn_statement under key: %s" % nkey)
|
||||||
try:
|
try:
|
||||||
self.authn[nkey].append(authn_statement)
|
self.authn[nkey].append(authn_statement)
|
||||||
@@ -221,8 +224,14 @@ class Server(Entity):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
key = sha1("%s" % name_id).hexdigest()
|
key = sha1(code(name_id)).hexdigest()
|
||||||
for statement in self.authn[key]:
|
try:
|
||||||
|
statements = self.authn[key]
|
||||||
|
except KeyError:
|
||||||
|
logger.info("Unknown subject %s" % name_id)
|
||||||
|
return []
|
||||||
|
|
||||||
|
for statement in statements:
|
||||||
if session_index:
|
if session_index:
|
||||||
if statement.session_index != session_index:
|
if statement.session_index != session_index:
|
||||||
continue
|
continue
|
||||||
@@ -234,7 +243,8 @@ class Server(Entity):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def remove_authn_statements(self, name_id):
|
def remove_authn_statements(self, name_id):
|
||||||
nkey = sha1("%s" % name_id).hexdigest()
|
logger.debug("remove authn about: %s" % name_id)
|
||||||
|
nkey = sha1(code(name_id)).hexdigest()
|
||||||
|
|
||||||
del self.authn[nkey]
|
del self.authn[nkey]
|
||||||
|
|
||||||
|
@@ -21,11 +21,15 @@ Based on the use of xmlsec1 binaries and not the python xmlsec module.
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
from binascii import hexlify
|
from binascii import hexlify
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from time import mktime
|
||||||
|
import urllib
|
||||||
import M2Crypto
|
import M2Crypto
|
||||||
|
from M2Crypto.X509 import load_cert_string
|
||||||
from saml2.samlp import Response
|
from saml2.samlp import Response
|
||||||
|
|
||||||
import xmldsig as ds
|
import xmldsig as ds
|
||||||
@@ -37,8 +41,11 @@ from saml2 import ExtensionElement
|
|||||||
from saml2 import VERSION
|
from saml2 import VERSION
|
||||||
|
|
||||||
from saml2.s_utils import sid
|
from saml2.s_utils import sid
|
||||||
|
from saml2.s_utils import Unsupported
|
||||||
|
|
||||||
from saml2.time_util import instant
|
from saml2.time_util import instant
|
||||||
|
from saml2.time_util import utc_now
|
||||||
|
from saml2.time_util import str_to_time
|
||||||
|
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from subprocess import Popen, PIPE
|
from subprocess import Popen, PIPE
|
||||||
@@ -47,6 +54,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature")
|
SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature")
|
||||||
|
|
||||||
|
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
|
||||||
|
|
||||||
|
|
||||||
def signed(item):
|
def signed(item):
|
||||||
if SIG in item.c_children.keys() and item.signature:
|
if SIG in item.c_children.keys() and item.signature:
|
||||||
return True
|
return True
|
||||||
@@ -62,6 +72,7 @@ def signed(item):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_xmlsec_binary(paths=None):
|
def get_xmlsec_binary(paths=None):
|
||||||
"""
|
"""
|
||||||
Tries to find the xmlsec1 binary.
|
Tries to find the xmlsec1 binary.
|
||||||
@@ -109,45 +120,34 @@ ENC_KEY_CLASS = "EncryptedKey"
|
|||||||
|
|
||||||
_TEST_ = True
|
_TEST_ = True
|
||||||
|
|
||||||
|
|
||||||
class SignatureError(Exception):
|
class SignatureError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class XmlsecError(Exception):
|
class XmlsecError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MissingKey(Exception):
|
class MissingKey(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DecryptError(Exception):
|
class DecryptError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# --------------------------------------------------------------------------
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
#def make_signed_instance(klass, spec, seccont, base64encode=False):
|
|
||||||
# """ Will only return signed instance if the signature
|
|
||||||
# preamble is present
|
|
||||||
#
|
|
||||||
# :param klass: The type of class the instance should be
|
|
||||||
# :param spec: The specification of attributes and children of the class
|
|
||||||
# :param seccont: The security context (instance of SecurityContext)
|
|
||||||
# :param base64encode: Whether the attribute values should be base64 encoded
|
|
||||||
# :return: A signed (or possibly unsigned) instance of the class
|
|
||||||
# """
|
|
||||||
# if "signature" in spec:
|
|
||||||
# signed_xml = seccont.sign_statement_using_xmlsec("%s" % instance,
|
|
||||||
# class_name(instance), instance.id)
|
|
||||||
# return create_class_from_xml_string(instance.__class__, signed_xml)
|
|
||||||
# else:
|
|
||||||
# return make_instance(klass, spec, base64encode)
|
|
||||||
|
|
||||||
def xmlsec_version(execname):
|
def xmlsec_version(execname):
|
||||||
com_list = [execname,"--version"]
|
com_list = [execname, "--version"]
|
||||||
pof = Popen(com_list, stderr=PIPE, stdout=PIPE)
|
pof = Popen(com_list, stderr=PIPE, stdout=PIPE)
|
||||||
try:
|
try:
|
||||||
return pof.stdout.read().split(" ")[1]
|
return pof.stdout.read().split(" ")[1]
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
||||||
base64encode=False, elements_to_sign=None):
|
base64encode=False, elements_to_sign=None):
|
||||||
"""
|
"""
|
||||||
@@ -188,6 +188,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False,
|
|||||||
cis = [cinst]
|
cis = [cinst]
|
||||||
setattr(klass_inst, prop, cis)
|
setattr(klass_inst, prop, cis)
|
||||||
|
|
||||||
|
|
||||||
def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
||||||
instance = klass()
|
instance = klass()
|
||||||
|
|
||||||
@@ -208,7 +209,8 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
|||||||
#print "## %s, %s" % (prop, klassdef)
|
#print "## %s, %s" % (prop, klassdef)
|
||||||
if prop in ava:
|
if prop in ava:
|
||||||
#print "### %s" % ava[prop]
|
#print "### %s" % ava[prop]
|
||||||
if isinstance(klassdef, list): # means there can be a list of values
|
if isinstance(klassdef, list):
|
||||||
|
# means there can be a list of values
|
||||||
_make_vals(ava[prop], klassdef[0], seccont, instance, prop,
|
_make_vals(ava[prop], klassdef[0], seccont, instance, prop,
|
||||||
base64encode=base64encode,
|
base64encode=base64encode,
|
||||||
elements_to_sign=elements_to_sign)
|
elements_to_sign=elements_to_sign)
|
||||||
@@ -226,12 +228,12 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
|
|||||||
for key, val in ava["extension_attributes"].items():
|
for key, val in ava["extension_attributes"].items():
|
||||||
instance.extension_attributes[key] = val
|
instance.extension_attributes[key] = val
|
||||||
|
|
||||||
|
|
||||||
if "signature" in ava:
|
if "signature" in ava:
|
||||||
elements_to_sign.append((class_name(instance), instance.id))
|
elements_to_sign.append((class_name(instance), instance.id))
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -243,8 +245,8 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
|||||||
if elements_to_sign:
|
if elements_to_sign:
|
||||||
signed_xml = "%s" % instance
|
signed_xml = "%s" % instance
|
||||||
for (node_name, nodeid) in elements_to_sign:
|
for (node_name, nodeid) in elements_to_sign:
|
||||||
signed_xml = seccont.sign_statement_using_xmlsec(signed_xml,
|
signed_xml = seccont.sign_statement_using_xmlsec(
|
||||||
klass_namn=node_name, nodeid=nodeid)
|
signed_xml, klass_namn=node_name, nodeid=nodeid)
|
||||||
|
|
||||||
#print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
#print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
#print "%s" % signed_xml
|
#print "%s" % signed_xml
|
||||||
@@ -255,6 +257,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
|
|||||||
|
|
||||||
# --------------------------------------------------------------------------
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def create_id():
|
def create_id():
|
||||||
""" Create a string of 40 random characters from the set [a-p],
|
""" Create a string of 40 random characters from the set [a-p],
|
||||||
can be used as a unique identifier of objects.
|
can be used as a unique identifier of objects.
|
||||||
@@ -266,6 +269,7 @@ def create_id():
|
|||||||
ret += chr(random.randint(0, 15) + ord('a'))
|
ret += chr(random.randint(0, 15) + ord('a'))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def make_temp(string, suffix="", decode=True):
|
def make_temp(string, suffix="", decode=True):
|
||||||
""" xmlsec needs files in some cases where only strings exist, hence the
|
""" xmlsec needs files in some cases where only strings exist, hence the
|
||||||
need for this function. It creates a temporary file with the
|
need for this function. It creates a temporary file with the
|
||||||
@@ -288,10 +292,35 @@ def make_temp(string, suffix="", decode=True):
|
|||||||
ntf.seek(0)
|
ntf.seek(0)
|
||||||
return ntf, ntf.name
|
return ntf, ntf.name
|
||||||
|
|
||||||
def split_len(seq, length):
|
|
||||||
return [seq[i:i+length] for i in range(0, len(seq), length)]
|
|
||||||
|
|
||||||
def cert_from_key_info(key_info):
|
def split_len(seq, length):
|
||||||
|
return [seq[i:i + length] for i in range(0, len(seq), length)]
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
M2_TIME_FORMAT = "%b %d %H:%M:%S %Y"
|
||||||
|
|
||||||
|
|
||||||
|
def to_time(_time):
|
||||||
|
assert _time.endswith(" GMT")
|
||||||
|
_time = _time[:-4]
|
||||||
|
return mktime(str_to_time(_time, M2_TIME_FORMAT))
|
||||||
|
|
||||||
|
|
||||||
|
def active_cert(key):
|
||||||
|
cert_str = pem_format(key)
|
||||||
|
certificate = load_cert_string(cert_str)
|
||||||
|
not_before = to_time(str(certificate.get_not_before()))
|
||||||
|
not_after = to_time(str(certificate.get_not_after()))
|
||||||
|
try:
|
||||||
|
assert not_before < utc_now()
|
||||||
|
assert not_after > utc_now()
|
||||||
|
return True
|
||||||
|
except AssertionError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def cert_from_key_info(key_info, ignore_age=False):
|
||||||
""" Get all X509 certs from a KeyInfo instance. Care is taken to make sure
|
""" Get all X509 certs from a KeyInfo instance. Care is taken to make sure
|
||||||
that the certs are continues sequences of bytes.
|
that the certs are continues sequences of bytes.
|
||||||
|
|
||||||
@@ -307,12 +336,16 @@ def cert_from_key_info(key_info):
|
|||||||
#print "X509Data",x509_data
|
#print "X509Data",x509_data
|
||||||
x509_certificate = x509_data.x509_certificate
|
x509_certificate = x509_data.x509_certificate
|
||||||
cert = x509_certificate.text.strip()
|
cert = x509_certificate.text.strip()
|
||||||
cert = "\n".join(split_len("".join([
|
cert = "\n".join(split_len("".join([s.strip() for s in
|
||||||
s.strip() for s in cert.split()]),64))
|
cert.split()]), 64))
|
||||||
|
if ignore_age or active_cert(cert):
|
||||||
res.append(cert)
|
res.append(cert)
|
||||||
|
else:
|
||||||
|
logger.info("Inactive cert")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def cert_from_key_info_dict(key_info):
|
|
||||||
|
def cert_from_key_info_dict(key_info, ignore_age=False):
|
||||||
""" Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure
|
""" Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure
|
||||||
that the certs are continues sequences of bytes.
|
that the certs are continues sequences of bytes.
|
||||||
|
|
||||||
@@ -330,11 +363,15 @@ def cert_from_key_info_dict(key_info):
|
|||||||
for x509_data in key_info["x509_data"]:
|
for x509_data in key_info["x509_data"]:
|
||||||
x509_certificate = x509_data["x509_certificate"]
|
x509_certificate = x509_data["x509_certificate"]
|
||||||
cert = x509_certificate["text"].strip()
|
cert = x509_certificate["text"].strip()
|
||||||
cert = "\n".join(split_len("".join([
|
cert = "\n".join(split_len("".join([s.strip() for s in
|
||||||
s.strip() for s in cert.split()]),64))
|
cert.split()]), 64))
|
||||||
|
if ignore_age or active_cert(cert):
|
||||||
res.append(cert)
|
res.append(cert)
|
||||||
|
else:
|
||||||
|
logger.info("Inactive cert")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def cert_from_instance(instance):
|
def cert_from_instance(instance):
|
||||||
""" Find certificates that are part of an instance
|
""" Find certificates that are part of an instance
|
||||||
|
|
||||||
@@ -343,19 +380,23 @@ def cert_from_instance(instance):
|
|||||||
"""
|
"""
|
||||||
if instance.signature:
|
if instance.signature:
|
||||||
if instance.signature.key_info:
|
if instance.signature.key_info:
|
||||||
return cert_from_key_info(instance.signature.key_info)
|
return cert_from_key_info(instance.signature.key_info,
|
||||||
|
ignore_age=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
from M2Crypto.__m2crypto import bn_to_mpi
|
from M2Crypto.__m2crypto import bn_to_mpi
|
||||||
from M2Crypto.__m2crypto import hex_to_bn
|
from M2Crypto.__m2crypto import hex_to_bn
|
||||||
|
|
||||||
|
|
||||||
def intarr2long(arr):
|
def intarr2long(arr):
|
||||||
return long(''.join(["%02x" % byte for byte in arr]), 16)
|
return long(''.join(["%02x" % byte for byte in arr]), 16)
|
||||||
|
|
||||||
|
|
||||||
def dehexlify(bi):
|
def dehexlify(bi):
|
||||||
s = hexlify(bi)
|
s = hexlify(bi)
|
||||||
return [int(s[i]+s[i+1], 16) for i in range(0,len(s),2)]
|
return [int(s[i] + s[i + 1], 16) for i in range(0, len(s), 2)]
|
||||||
|
|
||||||
|
|
||||||
def long_to_mpi(num):
|
def long_to_mpi(num):
|
||||||
"""Converts a python integer or long to OpenSSL MPInt used by M2Crypto.
|
"""Converts a python integer or long to OpenSSL MPInt used by M2Crypto.
|
||||||
@@ -365,10 +406,12 @@ def long_to_mpi(num):
|
|||||||
h = '0' + h # add leading 0 to get even number of hexdigits
|
h = '0' + h # add leading 0 to get even number of hexdigits
|
||||||
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
|
return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum
|
||||||
|
|
||||||
|
|
||||||
def base64_to_long(data):
|
def base64_to_long(data):
|
||||||
_d = base64.urlsafe_b64decode(data + '==')
|
_d = base64.urlsafe_b64decode(data + '==')
|
||||||
return intarr2long(dehexlify(_d))
|
return intarr2long(dehexlify(_d))
|
||||||
|
|
||||||
|
|
||||||
def key_from_key_value(key_info):
|
def key_from_key_value(key_info):
|
||||||
res = []
|
res = []
|
||||||
for value in key_info.key_value:
|
for value in key_info.key_value:
|
||||||
@@ -380,6 +423,7 @@ def key_from_key_value(key_info):
|
|||||||
res.append(key)
|
res.append(key)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def key_from_key_value_dict(key_info):
|
def key_from_key_value_dict(key_info):
|
||||||
res = []
|
res = []
|
||||||
if not "key_value" in key_info:
|
if not "key_value" in key_info:
|
||||||
@@ -396,9 +440,27 @@ def key_from_key_value_dict(key_info):
|
|||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def rsa_load(filename):
|
||||||
|
"""Read a PEM-encoded RSA key pair from a file."""
|
||||||
|
return M2Crypto.RSA.load_key(filename, M2Crypto.util.no_passphrase_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def rsa_loads(key):
|
||||||
|
"""Read a PEM-encoded RSA key pair from a string."""
|
||||||
|
return M2Crypto.RSA.load_key_string(key,
|
||||||
|
M2Crypto.util.no_passphrase_callback)
|
||||||
|
|
||||||
|
|
||||||
|
def x509_rsa_loads(string):
|
||||||
|
cert = M2Crypto.X509.load_cert_string(string)
|
||||||
|
return cert.get_pubkey().get_rsa()
|
||||||
|
|
||||||
|
|
||||||
def pem_format(key):
|
def pem_format(key):
|
||||||
return "\n".join(["-----BEGIN CERTIFICATE-----",
|
return "\n".join(["-----BEGIN CERTIFICATE-----",
|
||||||
key,"-----END CERTIFICATE-----"])
|
key, "-----END CERTIFICATE-----"])
|
||||||
|
|
||||||
|
|
||||||
def parse_xmlsec_output(output):
|
def parse_xmlsec_output(output):
|
||||||
""" Parse the output from xmlsec to try to find out if the
|
""" Parse the output from xmlsec to try to find out if the
|
||||||
@@ -416,8 +478,80 @@ def parse_xmlsec_output(output):
|
|||||||
|
|
||||||
__DEBUG = 0
|
__DEBUG = 0
|
||||||
|
|
||||||
LOG_LINE = 60*"="+"\n%s\n"+60*"-"+"\n%s"+60*"="
|
|
||||||
LOG_LINE_2 = 60*"="+"\n%s\n%s\n"+60*"-"+"\n%s"+60*"="
|
class BadSignature(Exception):
|
||||||
|
"""The signature is invalid."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def sha1_digest(msg):
|
||||||
|
return hashlib.sha1(msg).digest()
|
||||||
|
|
||||||
|
|
||||||
|
class Signer(object):
|
||||||
|
"""Abstract base class for signing algorithms."""
|
||||||
|
def sign(self, msg, key):
|
||||||
|
"""Sign ``msg`` with ``key`` and return the signature."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def verify(self, msg, sig, key):
|
||||||
|
"""Return True if ``sig`` is a valid signature for ``msg``."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class RSASigner(Signer):
|
||||||
|
def __init__(self, digest, algo):
|
||||||
|
self.digest = digest
|
||||||
|
self.algo = algo
|
||||||
|
|
||||||
|
def sign(self, msg, key):
|
||||||
|
return key.sign(self.digest(msg), self.algo)
|
||||||
|
|
||||||
|
def verify(self, msg, sig, key):
|
||||||
|
try:
|
||||||
|
return key.verify(self.digest(msg), sig, self.algo)
|
||||||
|
except M2Crypto.RSA.RSAError, e:
|
||||||
|
raise BadSignature(e)
|
||||||
|
|
||||||
|
|
||||||
|
REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"]
|
||||||
|
RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"]
|
||||||
|
|
||||||
|
|
||||||
|
def verify_redirect_signature(info, cert):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param info: A dictionary as produced by parse_qs, means all values are
|
||||||
|
lists.
|
||||||
|
:param cert: A certificate to use when verifying the signature
|
||||||
|
:return: True, if signature verified
|
||||||
|
"""
|
||||||
|
|
||||||
|
if info["SigAlg"][0] == RSA_SHA1:
|
||||||
|
if "SAMLRequest" in info:
|
||||||
|
_order = REQ_ORDER
|
||||||
|
elif "SAMLResponse" in info:
|
||||||
|
_order = RESP_ORDER
|
||||||
|
else:
|
||||||
|
raise Unsupported(
|
||||||
|
"Verifying signature on something that should not be signed")
|
||||||
|
signer = RSASigner(sha1_digest, "sha1")
|
||||||
|
args = info.copy()
|
||||||
|
del args["Signature"] # everything but the signature
|
||||||
|
string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order])
|
||||||
|
_key = x509_rsa_loads(pem_format(cert))
|
||||||
|
_sign = base64.b64decode(info["Signature"][0])
|
||||||
|
try:
|
||||||
|
signer.verify(string, _sign, _key)
|
||||||
|
return True
|
||||||
|
except BadSignature:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
raise Unsupported("Signature algorithm: %s" % info["SigAlg"])
|
||||||
|
|
||||||
|
LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||||
|
LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
|
||||||
|
|
||||||
|
|
||||||
def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
|
def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
|
||||||
node_name=NODE_NAME, debug=False, node_id=None,
|
node_name=NODE_NAME, debug=False, node_id=None,
|
||||||
@@ -481,6 +615,7 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem",
|
|||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def read_cert_from_file(cert_file, cert_type):
|
def read_cert_from_file(cert_file, cert_type):
|
||||||
""" Reads a certificate from a file. The assumption is that there is
|
""" Reads a certificate from a file. The assumption is that there is
|
||||||
only one certificate in the file
|
only one certificate in the file
|
||||||
@@ -516,6 +651,7 @@ def read_cert_from_file(cert_file, cert_type):
|
|||||||
data = open(cert_file).read()
|
data = open(cert_file).read()
|
||||||
return base64.b64encode(str(data))
|
return base64.b64encode(str(data))
|
||||||
|
|
||||||
|
|
||||||
def security_context(conf, debug=None):
|
def security_context(conf, debug=None):
|
||||||
""" Creates a security context based on the configuration
|
""" Creates a security context based on the configuration
|
||||||
|
|
||||||
@@ -538,8 +674,9 @@ def security_context(conf, debug=None):
|
|||||||
cert_file=conf.cert_file, metadata=metadata,
|
cert_file=conf.cert_file, metadata=metadata,
|
||||||
debug=debug, only_use_keys_in_metadata=_only_md)
|
debug=debug, only_use_keys_in_metadata=_only_md)
|
||||||
|
|
||||||
|
|
||||||
class SecurityContext(object):
|
class SecurityContext(object):
|
||||||
def __init__(self, xmlsec_binary, key_file="", key_type= "pem",
|
def __init__(self, xmlsec_binary, key_file="", key_type="pem",
|
||||||
cert_file="", cert_type="pem", metadata=None,
|
cert_file="", cert_type="pem", metadata=None,
|
||||||
debug=False, template="", encrypt_key_type="des-192",
|
debug=False, template="", encrypt_key_type="des-192",
|
||||||
only_use_keys_in_metadata=False):
|
only_use_keys_in_metadata=False):
|
||||||
@@ -592,12 +729,9 @@ class SecurityContext(object):
|
|||||||
_, fil = make_temp("%s" % text, decode=False)
|
_, fil = make_temp("%s" % text, decode=False)
|
||||||
ntf = NamedTemporaryFile()
|
ntf = NamedTemporaryFile()
|
||||||
|
|
||||||
com_list = [self.xmlsec, "--encrypt",
|
com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key,
|
||||||
"--pubkey-pem", recv_key,
|
"--session-key", key_type, "--xml-data", fil,
|
||||||
"--session-key", key_type,
|
"--output", ntf.name, template]
|
||||||
"--xml-data", fil,
|
|
||||||
"--output", ntf.name,
|
|
||||||
template]
|
|
||||||
|
|
||||||
logger.debug("Encryption command: %s" % " ".join(com_list))
|
logger.debug("Encryption command: %s" % " ".join(com_list))
|
||||||
|
|
||||||
@@ -625,11 +759,9 @@ class SecurityContext(object):
|
|||||||
_, fil = make_temp("%s" % enctext, decode=False)
|
_, fil = make_temp("%s" % enctext, decode=False)
|
||||||
ntf = NamedTemporaryFile()
|
ntf = NamedTemporaryFile()
|
||||||
|
|
||||||
com_list = [self.xmlsec, "--decrypt",
|
com_list = [self.xmlsec, "--decrypt", "--privkey-pem",
|
||||||
"--privkey-pem", self.key_file,
|
self.key_file, "--output", ntf.name,
|
||||||
"--output", ntf.name,
|
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, fil]
|
||||||
"--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS,
|
|
||||||
fil]
|
|
||||||
|
|
||||||
logger.debug("Decrypt command: %s" % " ".join(com_list))
|
logger.debug("Decrypt command: %s" % " ".join(com_list))
|
||||||
|
|
||||||
@@ -646,7 +778,6 @@ class SecurityContext(object):
|
|||||||
ntf.seek(0)
|
ntf.seek(0)
|
||||||
return ntf.read()
|
return ntf.read()
|
||||||
|
|
||||||
|
|
||||||
def verify_signature(self, enctext, cert_file=None, cert_type="pem",
|
def verify_signature(self, enctext, cert_file=None, cert_type="pem",
|
||||||
node_name=NODE_NAME, node_id=None, id_attr=""):
|
node_name=NODE_NAME, node_id=None, id_attr=""):
|
||||||
""" Verifies the signature of a XML document.
|
""" Verifies the signature of a XML document.
|
||||||
@@ -889,11 +1020,9 @@ class SecurityContext(object):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
#--------------------------------------------------------------------------
|
#--------------------------------------------------------------------------
|
||||||
# SIGNATURE PART
|
# SIGNATURE PART
|
||||||
#--------------------------------------------------------------------------
|
#--------------------------------------------------------------------------
|
||||||
|
|
||||||
def sign_statement_using_xmlsec(self, statement, klass_namn, key=None,
|
def sign_statement_using_xmlsec(self, statement, klass_namn, key=None,
|
||||||
key_file=None, nodeid=None, id_attr=""):
|
key_file=None, nodeid=None, id_attr=""):
|
||||||
"""Sign a SAML statement using xmlsec.
|
"""Sign a SAML statement using xmlsec.
|
||||||
@@ -917,7 +1046,6 @@ class SecurityContext(object):
|
|||||||
|
|
||||||
_, fil = make_temp("%s" % statement, decode=False)
|
_, fil = make_temp("%s" % statement, decode=False)
|
||||||
|
|
||||||
|
|
||||||
ntf = NamedTemporaryFile()
|
ntf = NamedTemporaryFile()
|
||||||
|
|
||||||
com_list = [self.xmlsec, "--sign",
|
com_list = [self.xmlsec, "--sign",
|
||||||
@@ -975,10 +1103,8 @@ class SecurityContext(object):
|
|||||||
:return: The signed statement
|
:return: The signed statement
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.sign_statement_using_xmlsec(statement,
|
return self.sign_statement_using_xmlsec(statement, class_name(
|
||||||
class_name(samlp.AttributeQuery()),
|
samlp.AttributeQuery()), key, key_file, nodeid, id_attr=id_attr)
|
||||||
key, key_file, nodeid,
|
|
||||||
id_attr=id_attr)
|
|
||||||
|
|
||||||
def multiple_signatures(self, statement, to_sign, key=None, key_file=None):
|
def multiple_signatures(self, statement, to_sign, key=None, key_file=None):
|
||||||
"""
|
"""
|
||||||
@@ -991,15 +1117,15 @@ class SecurityContext(object):
|
|||||||
:param key_file: A file that contains the key to be used
|
:param key_file: A file that contains the key to be used
|
||||||
:return: A possibly multiple signed statement
|
:return: A possibly multiple signed statement
|
||||||
"""
|
"""
|
||||||
for (item, id, id_attr) in to_sign:
|
for (item, sid, id_attr) in to_sign:
|
||||||
if not id:
|
if not sid:
|
||||||
if not item.id:
|
if not item.id:
|
||||||
id = item.id = sid()
|
sid = item.id = sid()
|
||||||
else:
|
else:
|
||||||
id = item.id
|
sid = item.id
|
||||||
|
|
||||||
if not item.signature:
|
if not item.signature:
|
||||||
item.signature = pre_signature_part(id, self.cert_file)
|
item.signature = pre_signature_part(sid, self.cert_file)
|
||||||
|
|
||||||
statement = self.sign_statement_using_xmlsec(statement,
|
statement = self.sign_statement_using_xmlsec(statement,
|
||||||
class_name(item),
|
class_name(item),
|
||||||
@@ -1024,22 +1150,20 @@ def pre_signature_part(ident, public_key=None, identifier=None):
|
|||||||
:return: A preset signature part
|
:return: A preset signature part
|
||||||
"""
|
"""
|
||||||
|
|
||||||
signature_method = ds.SignatureMethod(algorithm = ds.SIG_RSA_SHA1)
|
signature_method = ds.SignatureMethod(algorithm=ds.SIG_RSA_SHA1)
|
||||||
canonicalization_method = ds.CanonicalizationMethod(
|
canonicalization_method = ds.CanonicalizationMethod(
|
||||||
algorithm = ds.ALG_EXC_C14N)
|
algorithm=ds.ALG_EXC_C14N)
|
||||||
trans0 = ds.Transform(algorithm = ds.TRANSFORM_ENVELOPED)
|
trans0 = ds.Transform(algorithm=ds.TRANSFORM_ENVELOPED)
|
||||||
trans1 = ds.Transform(algorithm = ds.ALG_EXC_C14N)
|
trans1 = ds.Transform(algorithm=ds.ALG_EXC_C14N)
|
||||||
transforms = ds.Transforms(transform = [trans0, trans1])
|
transforms = ds.Transforms(transform=[trans0, trans1])
|
||||||
digest_method = ds.DigestMethod(algorithm = ds.DIGEST_SHA1)
|
digest_method = ds.DigestMethod(algorithm=ds.DIGEST_SHA1)
|
||||||
|
|
||||||
reference = ds.Reference(uri = "#%s" % ident,
|
reference = ds.Reference(uri="#%s" % ident, digest_value=ds.DigestValue(),
|
||||||
digest_value = ds.DigestValue(),
|
transforms=transforms, digest_method=digest_method)
|
||||||
transforms = transforms,
|
|
||||||
digest_method = digest_method)
|
|
||||||
|
|
||||||
signed_info = ds.SignedInfo(signature_method = signature_method,
|
signed_info = ds.SignedInfo(signature_method=signature_method,
|
||||||
canonicalization_method = canonicalization_method,
|
canonicalization_method=canonicalization_method,
|
||||||
reference = reference)
|
reference=reference)
|
||||||
|
|
||||||
signature = ds.Signature(signed_info=signed_info,
|
signature = ds.Signature(signed_info=signed_info,
|
||||||
signature_value=ds.SignatureValue())
|
signature_value=ds.SignatureValue())
|
||||||
@@ -1048,8 +1172,8 @@ def pre_signature_part(ident, public_key=None, identifier=None):
|
|||||||
signature.id = "Signature%d" % identifier
|
signature.id = "Signature%d" % identifier
|
||||||
|
|
||||||
if public_key:
|
if public_key:
|
||||||
x509_data = ds.X509Data(x509_certificate=[ds.X509Certificate(
|
x509_data = ds.X509Data(
|
||||||
text=public_key)])
|
x509_certificate=[ds.X509Certificate(text=public_key)])
|
||||||
key_info = ds.KeyInfo(x509_data=x509_data)
|
key_info = ds.KeyInfo(x509_data=x509_data)
|
||||||
signature.key_info = key_info
|
signature.key_info = key_info
|
||||||
|
|
||||||
|
@@ -149,13 +149,13 @@ def add_duration(tid, duration):
|
|||||||
if days < 1:
|
if days < 1:
|
||||||
pass
|
pass
|
||||||
elif days > maximum_day_in_month_for(year, month):
|
elif days > maximum_day_in_month_for(year, month):
|
||||||
days = days - maximum_day_in_month_for(year, month)
|
days -= maximum_day_in_month_for(year, month)
|
||||||
carry = 1
|
carry = 1
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
temp = month + carry
|
temp = month + carry
|
||||||
month = modulo(temp, 1, 13)
|
month = modulo(temp, 1, 13)
|
||||||
year = year + f_quotient(temp, 1, 13)
|
year += f_quotient(temp, 1, 13)
|
||||||
|
|
||||||
return time.localtime(time.mktime((year, month, days, hour, minutes,
|
return time.localtime(time.mktime((year, month, days, hour, minutes,
|
||||||
secs, 0, 0, -1)))
|
secs, 0, 0, -1)))
|
||||||
@@ -232,7 +232,7 @@ def str_to_time(timestr, format=TIME_FORMAT):
|
|||||||
return 0
|
return 0
|
||||||
try:
|
try:
|
||||||
then = time.strptime(timestr, format)
|
then = time.strptime(timestr, format)
|
||||||
except Exception: # assume it's a format problem
|
except ValueError, err: # assume it's a format problem
|
||||||
try:
|
try:
|
||||||
elem = TIME_FORMAT_WITH_FRAGMENT.match(timestr)
|
elem = TIME_FORMAT_WITH_FRAGMENT.match(timestr)
|
||||||
except Exception, exc:
|
except Exception, exc:
|
||||||
|
44
tests/test_70_redirect_signing.py
Normal file
44
tests/test_70_redirect_signing.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from saml2.pack import http_redirect_message
|
||||||
|
from saml2.sigver import verify_redirect_signature
|
||||||
|
from saml2.sigver import RSA_SHA1
|
||||||
|
from saml2.server import Server
|
||||||
|
from saml2 import BINDING_HTTP_REDIRECT
|
||||||
|
from saml2.client import Saml2Client
|
||||||
|
from saml2.config import SPConfig
|
||||||
|
from saml2.sigver import rsa_load
|
||||||
|
from urlparse import parse_qs
|
||||||
|
|
||||||
|
__author__ = 'rolandh'
|
||||||
|
|
||||||
|
idp = Server(config_file="idp_all_conf")
|
||||||
|
|
||||||
|
conf = SPConfig()
|
||||||
|
conf.load_file("servera_conf")
|
||||||
|
sp = Saml2Client(conf)
|
||||||
|
|
||||||
|
def test():
|
||||||
|
srvs = sp.metadata.single_sign_on_service(idp.config.entityid,
|
||||||
|
BINDING_HTTP_REDIRECT)
|
||||||
|
|
||||||
|
destination = srvs[0]["location"]
|
||||||
|
req = sp.create_authn_request(destination, id="id1")
|
||||||
|
|
||||||
|
try:
|
||||||
|
key = sp.sec.key
|
||||||
|
except AttributeError:
|
||||||
|
key = rsa_load(sp.sec.key_file)
|
||||||
|
|
||||||
|
info = http_redirect_message(req, destination, relay_state="RS",
|
||||||
|
typ="SAMLRequest", sigalg=RSA_SHA1, key=key)
|
||||||
|
|
||||||
|
verified_ok = False
|
||||||
|
|
||||||
|
for param, val in info["headers"]:
|
||||||
|
if param == "Location":
|
||||||
|
_dict = parse_qs(val.split("?")[1])
|
||||||
|
_certs = idp.metadata.certs(sp.config.entityid, "any", "signing")
|
||||||
|
for cert in _certs:
|
||||||
|
if verify_redirect_signature(_dict, cert):
|
||||||
|
verified_ok = True
|
||||||
|
|
||||||
|
assert verified_ok
|
Reference in New Issue
Block a user