From 9bbbfc29625aab6638428ef712aad21e42f51685 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Mon, 11 Feb 2013 14:32:35 +0100 Subject: [PATCH 1/8] Made log out work in the sp+idp2 example . --- example/idp2/idp.py | 43 +++++++---- example/sp/sp.py | 14 ++-- example/sp/sp_conf.py | 26 +++---- example/sp/who.ini | 3 +- src/s2repoze/plugins/sp.py | 152 ++++++++++++++++++------------------- src/saml2/client.py | 6 +- src/saml2/entity.py | 8 +- src/saml2/server.py | 14 ++-- 8 files changed, 140 insertions(+), 126 deletions(-) diff --git a/example/idp2/idp.py b/example/idp2/idp.py index e78cee9..6ba5883 100755 --- a/example/idp2/idp.py +++ b/example/idp2/idp.py @@ -31,6 +31,7 @@ from saml2.saml import AUTHN_PASSWORD logger = logging.getLogger("saml2.idp") + def _expiration(timeout, tformat="%a, %d-%b-%Y %H:%M:%S GMT"): """ @@ -50,7 +51,7 @@ def _expiration(timeout, tformat="%a, %d-%b-%Y %H:%M:%S GMT"): def dict2list_of_tuples(d): - return [(k,v) for k,v in d.items()] + return [(k, v) for k, v in d.items()] # ----------------------------------------------------------------------------- @@ -65,7 +66,7 @@ class Service(object): def unpack_redirect(self): if "QUERY_STRING" in self.environ: _qs = self.environ["QUERY_STRING"] - return dict([(k,v[0]) for k,v in parse_qs(_qs).items()]) + return dict([(k, v[0]) for k, v in parse_qs(_qs).items()]) else: return None @@ -270,8 +271,13 @@ class SSO(Service): IDP.ticket[key] = _dict _resp = key else: - _resp = IDP.ticket[_dict["key"]] - del IDP.ticket[_dict["key"]] + try: + _resp = IDP.ticket[_dict["key"]] + del IDP.ticket[_dict["key"]] + except KeyError: + key = sha1("%s" % _dict).hexdigest() + IDP.ticket[key] = _dict + _resp = key return _resp @@ -391,8 +397,9 @@ def do_verify(environ, start_response, _): if not _ok: resp = Unauthorized("Unknown user or wrong password") else: - uid = rndstr() - IDP.authn[uid] = user + uid = rndstr(24) + IDP.uid2user[uid] = user + IDP.user2uid[user] = uid logger.debug("Register %s under '%s'" % (user, uid)) kaka = set_cookie("idpauthn", "/", uid) lox = "http://%s%s?id=%s&key=%s" % (environ["HTTP_HOST"], @@ -437,6 +444,8 @@ class SLO(Service): if msg.name_id: lid = IDP.ident.find_local_id(msg.name_id) logger.info("local identifier: %s" % lid) + del IDP.uid2user[IDP.user2uid[lid]] + del IDP.user2uid[lid] # remove the authentication try: IDP.remove_authn_statements(msg.name_id) @@ -445,7 +454,7 @@ class SLO(Service): resp = ServiceError("%s" % exc) return resp(self.environ, self.start_response) - resp = IDP.create_logout_response(msg) + resp = IDP.create_logout_response(msg, [binding]) try: hinfo = IDP.apply_binding(binding, "%s" % resp, "", relay_state) @@ -454,11 +463,11 @@ class SLO(Service): resp = ServiceError("%s" % exc) return resp(self.environ, self.start_response) - logger.info("Header: %s" % (hinfo["headers"],)) #_tlh = dict2list_of_tuples(hinfo["headers"]) delco = delete_cookie(self.environ, "idpauthn") if delco: hinfo["headers"].append(delco) + logger.info("Header: %s" % (hinfo["headers"],)) resp = Response(hinfo["data"], headers=hinfo["headers"]) return resp(self.environ, self.start_response) @@ -475,9 +484,9 @@ class NMI(Service): request = req.message # Do the necessary stuff - name_id = IDP.ident.handle_manage_name_id_request(request.name_id, - request.new_id, request.new_encrypted_id, - request.terminate) + name_id = IDP.ident.handle_manage_name_id_request( + request.name_id, request.new_id, request.new_encrypted_id, + request.terminate) logger.debug("New NameID: %s" % name_id) @@ -606,8 +615,8 @@ class NIM(Service): request = req.message # Do the necessary stuff try: - name_id = IDP.ident.handle_name_id_mapping_request(request.name_id, - request.name_id_policy) + name_id = IDP.ident.handle_name_id_mapping_request( + request.name_id, request.name_id_policy) except Unknown: resp = BadRequest("Unknown entity") return resp(self.environ, self.start_response) @@ -638,7 +647,10 @@ def kaka2user(kaka): cookie_obj = SimpleCookie(kaka) morsel = cookie_obj.get("idpauthn", None) if morsel: - return IDP.authn[morsel.value] + try: + return IDP.uid2user[morsel.value] + except KeyError: + return None else: logger.debug("No idpauthn cookie") return None @@ -646,6 +658,7 @@ def kaka2user(kaka): def delete_cookie(environ, name): kaka = environ.get("HTTP_COOKIE", '') + logger.debug("delete KAKA: %s" % kaka) if kaka: cookie_obj = SimpleCookie(kaka) morsel = cookie_obj.get(name, None) @@ -739,7 +752,7 @@ def application(environ, start_response): try: query = parse_qs(environ["QUERY_STRING"]) logger.debug("QUERY: %s" % query) - user = IDP.authn[query["id"][0]] + user = IDP.uid2user[query["id"][0]] except KeyError: user = None diff --git a/example/sp/sp.py b/example/sp/sp.py index d57a1d6..9e5b63a 100755 --- a/example/sp/sp.py +++ b/example/sp/sp.py @@ -145,8 +145,10 @@ def logout(environ, start_response, user): # What if more than one _dict = client.saml_client.global_logout(subject_id) logger.info("[logout] global_logout > %s" % (_dict,)) + rem = environ['repoze.who.plugins'][client.rememberer_name] + rem.forget(environ, subject_id) - for key, item in _dict.item(): + for key, item in _dict.items(): if isinstance(item, tuple): binding, htargs = item else: # result from logout, should be OK @@ -200,13 +202,15 @@ def application(environ, start_response): request is done :return: The response as a list of lines """ + path = environ.get('PATH_INFO', '').lstrip('/') + logger.info(" PATH: %s" % path) + user = environ.get("REMOTE_USER", "") if not user: user = environ.get("repoze.who.identity", "") - - path = environ.get('PATH_INFO', '').lstrip('/') - logger.info(" PATH: %s" % path) - logger.info("logger name: %s" % logger.name) + logger.info("repoze.who.identity: '%s'" % user) + else: + logger.info("REMOTE_USER: '%s'" % user) #logger.info(logging.Logger.manager.loggerDict) for regex, callback in urls: if user: diff --git a/example/sp/sp_conf.py b/example/sp/sp_conf.py index 8411599..b8223df 100644 --- a/example/sp/sp_conf.py +++ b/example/sp/sp_conf.py @@ -5,14 +5,14 @@ BASE= "http://localhost:8087" #BASE= "http://lingon.catalogix.se:8087" CONFIG = { - "entityid" : "%s/sp.xml" % BASE, + "entityid": "%s/sp.xml" % BASE, "description": "My SP", "service": { - "sp":{ + "sp": { "name" : "Rolands SP", - "endpoints":{ + "endpoints": { "assertion_consumer_service": [BASE], - "single_logout_service" : [(BASE+"/slo", + "single_logout_service": [(BASE + "/slo", BINDING_HTTP_REDIRECT)], }, "required_attributes": ["surname", "givenname", @@ -20,18 +20,16 @@ CONFIG = { "optional_attributes": ["title"], } }, - "debug" : 1, - "key_file" : "pki/mykey.pem", - "cert_file" : "pki/mycert.pem", - "attribute_map_dir" : "./attributemaps", - "metadata" : { - "local": ["../idp2/idp.xml"], - }, + "debug": 1, + "key_file": "pki/mykey.pem", + "cert_file": "pki/mycert.pem", + "attribute_map_dir": "./attributemaps", + "metadata": {"local": ["../idp2/idp.xml"]}, # -- below used by make_metadata -- "organization": { "name": "Exempel AB", - "display_name": [("Exempel AB","se"),("Example Co.","en")], - "url":"http://www.example.com/roland", + "display_name": [("Exempel AB", "se"), ("Example Co.", "en")], + "url": "http://www.example.com/roland", }, "contact_person": [{ "given_name":"John", @@ -47,7 +45,7 @@ CONFIG = { "filename": "sp.log", "maxBytes": 100000, "backupCount": 5, - }, + }, "loglevel": "debug", } } diff --git a/example/sp/who.ini b/example/sp/who.ini index 77ea206..cd4c3e1 100644 --- a/example/sp/who.ini +++ b/example/sp/who.ini @@ -14,9 +14,8 @@ reissue_time = 3000 [plugin:saml2auth] use = s2repoze.plugins.sp:make_plugin saml_conf = sp_conf -rememberer_name = auth_tkt +remember_name = auth_tkt sid_store = outstanding -identity_cache = identities [general] request_classifier = s2repoze.plugins.challenge_decider:my_request_classifier diff --git a/src/s2repoze/plugins/sp.py b/src/s2repoze/plugins/sp.py index 567c02b..94ab3f9 100644 --- a/src/s2repoze/plugins/sp.py +++ b/src/s2repoze/plugins/sp.py @@ -42,6 +42,7 @@ from saml2 import ecp, BINDING_HTTP_REDIRECT from saml2 import BINDING_HTTP_POST from saml2.client import Saml2Client +from saml2.ident import code, decode from saml2.s_utils import sid from saml2.config import config_factory from saml2.profile import paos @@ -53,16 +54,18 @@ logger = logging.getLogger(__name__) PAOS_HEADER_INFO = 'ver="%s";"%s"' % (paos.NAMESPACE, ECP_SERVICE) + def construct_came_from(environ): """ The URL that the user used when the process where interupted for single-sign-on processing. """ came_from = environ.get("PATH_INFO") - qstr = environ.get("QUERY_STRING","") + qstr = environ.get("QUERY_STRING", "") if qstr: came_from += '?' + qstr return came_from + # FormPluginBase defines the methods remember and forget def cgi_field_storage_to_dict(field_storage): """Get a plain dictionary, rather than the '.value' system used by the @@ -71,16 +74,15 @@ def cgi_field_storage_to_dict(field_storage): params = {} for key in field_storage.keys(): try: - params[ key ] = field_storage[ key ].value + params[key] = field_storage[key].value except AttributeError: - if isinstance(field_storage[ key ], basestring): + if isinstance(field_storage[key], basestring): params[key] = field_storage[key] return params -def get_body(environ): - body = "" +def get_body(environ): length = int(environ["CONTENT_LENGTH"]) try: body = environ["wsgi.input"].read(length) @@ -95,11 +97,13 @@ def get_body(environ): return body + def exception_trace(tag, exc, log): message = traceback.format_exception(*sys.exc_info()) log.error("[%s] ExcList: %s" % (tag, "".join(message),)) log.error("[%s] Exception: %s" % (tag, exc)) + class ECP_response(object): code = 200 title = 'OK' @@ -114,13 +118,12 @@ class ECP_response(object): return [self.content] - class SAML2Plugin(FormPluginBase): implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider) - def __init__(self, rememberer_name, config, saml_client, - wayf, cache, debug, sid_store=None, discovery=""): + def __init__(self, rememberer_name, config, saml_client, wayf, cache, + debug, sid_store=None, discovery=""): FormPluginBase.__init__(self) self.rememberer_name = rememberer_name @@ -148,8 +151,6 @@ class SAML2Plugin(FormPluginBase): :param environ: A dictionary with environment variables """ - post = {} - post_env = environ.copy() post_env['QUERY_STRING'] = '' @@ -173,8 +174,8 @@ class SAML2Plugin(FormPluginBase): sid_ = sid() self.outstanding_queries[sid_] = came_from logger.info("Redirect to WAYF function: %s" % self.wayf) - return -1, HTTPSeeOther(headers = [('Location', - "%s?%s" % (self.wayf, sid_))]) + return -1, HTTPSeeOther(headers=[('Location', + "%s?%s" % (self.wayf, sid_))]) #noinspection PyUnusedLocal def _pick_idp(self, environ, came_from): @@ -189,6 +190,8 @@ class SAML2Plugin(FormPluginBase): # 'PAOS' : 'ver="%s";"%s"' % (paos.NAMESPACE, SERVICE) # } + _cli = self.saml_client + logger.info("[_pick_idp] %s" % environ) if "HTTP_PAOS" in environ: if environ["HTTP_PAOS"] == PAOS_HEADER_INFO: @@ -200,28 +203,26 @@ class SAML2Plugin(FormPluginBase): logger.info("- ECP client detected -") _relay_state = construct_came_from(environ) - _entityid = self.saml_client.config.ecp_endpoint( - environ["REMOTE_ADDR"]) + _entityid = _cli.config.ecp_endpoint(environ["REMOTE_ADDR"]) + if not _entityid: return -1, HTTPInternalServerError( - detail="No IdP to talk to" - ) + detail="No IdP to talk to") logger.info("IdP to talk to: %s" % _entityid) - return ecp.ecp_auth_request(self.saml_client, _entityid, + return ecp.ecp_auth_request(_cli, _entityid, _relay_state) else: return -1, HTTPInternalServerError( - detail='Faulty Accept header') + detail='Faulty Accept header') else: return -1, HTTPInternalServerError( - detail='unknown ECP version') - + detail='unknown ECP version') idps = self.metadata.with_descriptor("idpsso") logger.info("IdP URL: %s" % idps) - if len( idps ) == 1: + if len(idps) == 1: # idps is a dictionary idp_entity_id = idps.keys()[0] elif not len(idps): @@ -229,16 +230,17 @@ class SAML2Plugin(FormPluginBase): else: idp_entity_id = "" logger.info("ENVIRON: %s" % environ) - query = environ.get('s2repoze.body','') + query = environ.get('s2repoze.body', '') if not query: - query = environ.get("QUERY_STRING","") + query = environ.get("QUERY_STRING", "") logger.info("<_pick_idp> query: %s" % query) if self.wayf: if query: try: - wayf_selected = dict(parse_qs(query))["wayf_selected"][0] + wayf_selected = dict(parse_qs(query))[ + "wayf_selected"][0] except KeyError: return self._wayf_redirect(came_from) idp_entity_id = wayf_selected @@ -246,16 +248,16 @@ class SAML2Plugin(FormPluginBase): return self._wayf_redirect(came_from) elif self.discosrv: if query: - idp_entity_id = self.saml_client.parse_discovery_service_response( - query=environ.get("QUERY_STRING")) + idp_entity_id = _cli.parse_discovery_service_response( + query=environ.get("QUERY_STRING")) else: sid_ = sid() self.outstanding_queries[sid_] = came_from logger.info("Redirect to Discovery Service function") - eid = self.saml_client.config.entity_id - loc = self.saml_client.create_discovery_service_request( - self.discosrv, eid) - return -1, HTTPSeeOther(headers = [('Location',loc)]) + eid = _cli.config.entity_id + loc = _cli.create_discovery_service_request(self.discosrv, + eid) + return -1, HTTPSeeOther(headers=[('Location', loc)]) else: return -1, HTTPNotImplemented(detail='No WAYF or DJ present!') @@ -266,8 +268,10 @@ class SAML2Plugin(FormPluginBase): #noinspection PyUnusedLocal def challenge(self, environ, _status, _app_headers, _forget_headers): - # this challenge consist in login out - if environ.has_key('rwpc.logout'): + _cli = self.saml_client + + # this challenge consist in logging out + if 'rwpc.logout' in environ: # ignore right now? pass @@ -283,7 +287,7 @@ class SAML2Plugin(FormPluginBase): vorg_name = environ["myapp.vo"] except KeyError: try: - vorg_name = self.saml_client.vorg._name + vorg_name = _cli.vorg._name except AttributeError: vorg_name = "" @@ -304,7 +308,6 @@ class SAML2Plugin(FormPluginBase): entity_id = response logger.info("[sp.challenge] entity_id: %s" % entity_id) # Do the AuthnRequest - _cli = self.saml_client _binding = BINDING_HTTP_REDIRECT try: srvs = _cli.metadata.single_sign_on_service(entity_id, _binding) @@ -319,7 +322,8 @@ class SAML2Plugin(FormPluginBase): logger.debug("ht_args: %s" % ht_args) except Exception, exc: logger.exception(exc) - raise Exception("Failed to construct the AuthnRequest: %s" % exc) + raise Exception( + "Failed to construct the AuthnRequest: %s" % exc) # remember the request self.outstanding_queries[_sid] = came_from @@ -327,14 +331,15 @@ class SAML2Plugin(FormPluginBase): if not ht_args["data"] and ht_args["headers"][0][0] == "Location": logger.debug('redirect to: %s' % ht_args["headers"][0][1]) return HTTPSeeOther(headers=ht_args["headers"]) - else : + else: return ht_args["data"] def _construct_identity(self, session_info): + cni = code(session_info["name_id"]) identity = { - "login": session_info["name_id"], + "login": cni, "password": "", - 'repoze.who.userid': session_info["name_id"], + 'repoze.who.userid': cni, "user": session_info["ava"], } logger.debug("Identity: %s" % identity) @@ -349,9 +354,8 @@ class SAML2Plugin(FormPluginBase): # Evaluate the response, returns a AuthnResponse instance try: authresp = self.saml_client.parse_authn_request_response( - post["SAMLResponse"], - BINDING_HTTP_POST, - self.outstanding_queries) + post["SAMLResponse"], BINDING_HTTP_POST, + self.outstanding_queries) except Exception, excp: logger.exception("Exception: %s" % (excp,)) raise @@ -413,7 +417,7 @@ class SAML2Plugin(FormPluginBase): pass try: - if not post.has_key("SAMLResponse"): + if "SAMLResponse" not in post: logger.info("[sp.identify] --- NOT SAMLResponse ---") # Not for me, put the post back where next in line can # find it @@ -424,8 +428,8 @@ class SAML2Plugin(FormPluginBase): # check for SAML2 authN response #if self.debug: try: - session_info = self._eval_authn_response(environ, - cgi_field_storage_to_dict(post)) + session_info = self._eval_authn_response( + environ, cgi_field_storage_to_dict(post)) except Exception: return None except TypeError, exc: @@ -445,37 +449,28 @@ class SAML2Plugin(FormPluginBase): if session_info: environ["s2repoze.sessioninfo"] = session_info - name_id = session_info["name_id"] - # contruct and return the identity - identity = { - "login": name_id, - "password": "", - 'repoze.who.userid': name_id, - "user": self.saml_client.users.get_identity(name_id)[0], - } - logger.info("[sp.identify] IDENTITY: %s" % (identity,)) - return identity + return self._construct_identity(session_info) else: return None - # IMetadataProvider def add_metadata(self, environ, identity): """ Add information to the knowledge I have about the user """ - subject_id = identity['repoze.who.userid'] - #logger = environ.get('repoze.who.logger','') + name_id = identity['repoze.who.userid'] + if isinstance(name_id, basestring): + name_id = decode(name_id) _cli = self.saml_client - logger.debug("[add_metadata] for %s" % subject_id) + logger.debug("[add_metadata] for %s" % name_id) try: - logger.debug("Issuers: %s" % _cli.users.sources(subject_id)) + logger.debug("Issuers: %s" % _cli.users.sources(name_id)) except KeyError: pass if "user" not in identity: identity["user"] = {} try: - (ava, _) = _cli.users.get_identity(subject_id) + (ava, _) = _cli.users.get_identity(name_id) #now = time.gmtime() logger.debug("[add_metadata] adds: %s" % ava) identity["user"].update(ava) @@ -486,9 +481,9 @@ class SAML2Plugin(FormPluginBase): # is this a Virtual Organization situation for vo in _cli.vorg.values(): try: - if vo.do_aggregation(subject_id): + if vo.do_aggregation(name_id): # Get the extended identity - identity["user"] = _cli.users.get_identity(subject_id)[0] + identity["user"] = _cli.users.get_identity(name_id)[0] # Only do this once, mark that the identity has been # expanded identity["pysaml2_vo_expanded"] = 1 @@ -505,7 +500,7 @@ class SAML2Plugin(FormPluginBase): # used 2 times : one to get the ticket, the other to validate it def _service_url(self, environ, qstr=None): if qstr is not None: - url = construct_url(environ, querystring = qstr) + url = construct_url(environ, querystring=qstr) else: url = construct_url(environ) return url @@ -519,32 +514,29 @@ class SAML2Plugin(FormPluginBase): return None -def make_plugin(rememberer_name=None, # plugin for remember - cache= "", # cache - # Which virtual organization to support - virtual_organization="", - saml_conf="", - wayf="", - sid_store="", - identity_cache="", - discovery="", - ): +def make_plugin(remember_name=None, # plugin for remember + cache="", # cache + # Which virtual organization to support + virtual_organization="", + saml_conf="", + wayf="", + sid_store="", + identity_cache="", + discovery="", + ): if saml_conf is "": raise ValueError( 'must include saml_conf in configuration') - if rememberer_name is None: - raise ValueError( - 'must include rememberer_name in configuration') + if remember_name is None: + raise ValueError('must include remember_name in configuration') conf = config_factory("sp", saml_conf) scl = Saml2Client(config=conf, identity_cache=identity_cache, - virtual_organization=virtual_organization) + virtual_organization=virtual_organization) - plugin = SAML2Plugin(rememberer_name, conf, scl, wayf, cache, sid_store, + plugin = SAML2Plugin(remember_name, conf, scl, wayf, cache, sid_store, discovery) return plugin - - diff --git a/src/saml2/client.py b/src/saml2/client.py index f7a7939..ccab7b1 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -18,6 +18,7 @@ """Contains classes and functions that a SAML2.0 Service Provider (SP) may use to conclude its tasks. """ +from saml2.ident import decode from saml2.httpbase import HTTPError from saml2.s_utils import sid import saml2 @@ -99,11 +100,14 @@ class Saml2Client(Base): conversation. """ + if isinstance(name_id, basestring): + name_id = decode(name_id) + logger.info("logout request for: %s" % name_id) # find out which IdPs/AAs I should notify entity_ids = self.users.issuers_of_info(name_id) - + self.users.remove_person(name_id) return self.do_logout(name_id, entity_ids, reason, expire, sign) def do_logout(self, name_id, entity_ids, reason, expire, sign=None): diff --git a/src/saml2/entity.py b/src/saml2/entity.py index 962fb1a..de241c6 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -204,11 +204,11 @@ class Entity(HTTPBase): raise Exception("Unkown entity or unsupported bindings") - def message_args(self, id=0): - if not id: - id = sid(self.seed) + def message_args(self, seid=0): + if not seid: + seid = sid(self.seed) - return {"id":id, "version":VERSION, + return {"id":seid, "version":VERSION, "issue_instant":instant(), "issuer":self._issuer()} def response_args(self, message, bindings=None, descr_type=""): diff --git a/src/saml2/server.py b/src/saml2/server.py index 17d7117..6bc04d0 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -49,7 +49,7 @@ from saml2.assertion import Policy from saml2.assertion import restriction_from_attribute_spec from saml2.assertion import filter_attribute_value_assertions -from saml2.ident import IdentDB +from saml2.ident import IdentDB, code #from saml2.profile import paos from saml2.profile import ecp @@ -70,6 +70,8 @@ class Server(Entity): self.ticket = {} self.authn = {} self.assertion = {} + self.user2uid = {} + self.uid2user = {} def init_config(self, stype="idp"): """ Remaining init of the server configuration @@ -194,8 +196,8 @@ class Server(Entity): def store_assertion(self, assertion, to_sign): self.assertion[assertion.id] = (assertion, to_sign) - def get_assertion(self, id): - return self.assertion[id] + def get_assertion(self, cid): + return self.assertion[cid] def store_authn_statement(self, authn_statement, name_id): """ @@ -204,7 +206,8 @@ class Server(Entity): :param name_id: :return: """ - nkey = sha1("%s" % name_id).hexdigest() + logger.debug("store authn about: %s" % name_id) + nkey = sha1(code(name_id)).hexdigest() logger.debug("Store authn_statement under key: %s" % nkey) try: self.authn[nkey].append(authn_statement) @@ -234,7 +237,8 @@ class Server(Entity): return result def remove_authn_statements(self, name_id): - nkey = sha1("%s" % name_id).hexdigest() + logger.debug("remove authn about: %s" % name_id) + nkey = sha1(code(name_id)).hexdigest() del self.authn[nkey] From 0d45a31259944ecb81964ec704b287697c4fb80a Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sun, 17 Feb 2013 20:50:45 +0100 Subject: [PATCH 2/8] Request HTTP response contains content not error. Bug in key calculation. --- src/saml2/httpbase.py | 2 +- src/saml2/server.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py index 9643e83..26aa551 100644 --- a/src/saml2/httpbase.py +++ b/src/saml2/httpbase.py @@ -304,7 +304,7 @@ class HTTPBase(object): logger.info("SOAP response: %s" % response.text) return response else: - raise HTTPError("%d:%s" % (response.status_code, response.error)) + raise HTTPError("%d:%s" % (response.status_code, response.content)) def add_credentials(self, user, passwd): self.user = user diff --git a/src/saml2/server.py b/src/saml2/server.py index 6bc04d0..b53d6c6 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -224,8 +224,14 @@ class Server(Entity): :return: """ result = [] - key = sha1("%s" % name_id).hexdigest() - for statement in self.authn[key]: + key = sha1(code(name_id)).hexdigest() + try: + statements = self.authn[key] + except KeyError: + logger.info("Unknown subject %s" % name_id) + return [] + + for statement in statements: if session_index: if statement.session_index != session_index: continue From 487eef79fc39ebf46f6dc4e112482a043ce36f6e Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 19 Feb 2013 09:42:12 +0100 Subject: [PATCH 3/8] Improved cookie handling added debug line --- src/saml2/httpbase.py | 69 +++++++++++++++++++++++++++++++------------ src/saml2/mdstore.py | 1 + 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py index 26aa551..11e369d 100644 --- a/src/saml2/httpbase.py +++ b/src/saml2/httpbase.py @@ -19,8 +19,8 @@ logger = logging.getLogger(__name__) __author__ = 'rolandh' -ATTRS = {"version":None, - "name":"", +ATTRS = {"version": None, + "name": "", "value": None, "port": None, "port_specified": False, @@ -43,19 +43,22 @@ PAIRS = { "path": "path_specified" } + class ConnectionError(Exception): pass + class HTTPError(Exception): pass + def _since_epoch(cdate): """ :param cdate: date format 'Wed, 06-Jun-2012 01:34:34 GMT' :return: UTC time """ - if len(cdate) < 29: # somethings broken + if len(cdate) < 29: # somethings broken if len(cdate) < 5: return utc_now() @@ -67,16 +70,19 @@ def _since_epoch(cdate): #return int(time.mktime(t)) return calendar.timegm(t) + def set_list2dict(sl): return dict(sl) + def dict2set_list(dic): - return [(k,v) for k,v in dic.items()] + return [(k, v) for k, v in dic.items()] + class HTTPBase(object): def __init__(self, verify=True, ca_bundle=None, key_file=None, cert_file=None): - self.request_args = {"allow_redirects": False,} + self.request_args = {"allow_redirects": False} #self.cookies = {} self.cookiejar = cookielib.CookieJar() @@ -98,6 +104,12 @@ class HTTPBase(object): :return: """ part = urlparse.urlparse(url) + + #if part.port: + # _domain = "%s:%s" % (part.hostname, part.port) + #else: + _domain = part.netloc + cookie_dict = {} now = utc_now() for _, a in list(self.cookiejar._cookies.items()): @@ -106,6 +118,8 @@ class HTTPBase(object): # print cookie if cookie.expires and cookie.expires <= now: continue + if not re.search("%s$" % cookie.domain, _domain): + continue if not re.match(cookie.path, part.path): continue @@ -116,8 +130,13 @@ class HTTPBase(object): def set_cookie(self, kaka, request): """Returns a cookielib.Cookie based on a set-cookie header line""" - # default rfc2109=False - # max-age, httponly + if not kaka: + return + + logger.debug("Set Cookie '%s'" % kaka) + part = urlparse.urlparse(request.url) + _domain = part.netloc + for cookie_name, morsel in kaka.items(): std_attr = ATTRS.copy() std_attr["name"] = cookie_name @@ -133,9 +152,9 @@ class HTTPBase(object): if attr in ATTRS: if morsel[attr]: if attr == "expires": - std_attr[attr]=_since_epoch(morsel[attr]) + std_attr[attr] = _since_epoch(morsel[attr]) else: - std_attr[attr]=morsel[attr] + std_attr[attr] = morsel[attr] elif attr == "max-age": if morsel["max-age"]: std_attr["expires"] = _since_epoch(morsel["max-age"]) @@ -144,8 +163,12 @@ class HTTPBase(object): if std_attr[att]: std_attr[item] = True - if std_attr["domain"] and std_attr["domain"].startswith("."): - std_attr["domain_initial_dot"] = True + if std_attr["domain"]: + if std_attr["domain"].startswith("."): + std_attr["domain_initial_dot"] = True + else: + std_attr["domain"] = _domain + std_attr["domain_specified"] = True if morsel["max-age"] is 0: try: @@ -154,6 +177,13 @@ class HTTPBase(object): name=std_attr["name"]) except ValueError: pass + elif morsel["expires"] < utc_now(): + try: + self.cookiejar.clear(domain=std_attr["domain"], + path=std_attr["path"], + name=std_attr["name"]) + except ValueError: + pass else: new_cookie = cookielib.Cookie(**std_attr) self.cookiejar.set_cookie(new_cookie) @@ -164,21 +194,22 @@ class HTTPBase(object): _kwargs.update(kwargs) if self.cookiejar: - _kwargs["cookies"] = self.cookies(url) + _cd = self.cookies(url) + if _cd: + _kwargs["cookies"] = _cd + logger.debug("Sent cookies: %s" % _kwargs["cookies"]) if self.user and self.passwd: - _kwargs["auth"]= (self.user, self.passwd) + _kwargs["auth"] = (self.user, self.passwd) - #logger.info("SENT COOKIEs: %s" % (_kwargs["cookies"],)) try: r = requests.request(method, url, **_kwargs) except requests.ConnectionError, exc: raise ConnectionError("%s" % exc) try: - #logger.info("RECEIVED COOKIEs: %s" % (r.headers["set-cookie"],)) self.set_cookie(SimpleCookie(r.headers["set-cookie"]), r) - except AttributeError, err: + except AttributeError: pass return r @@ -196,7 +227,7 @@ class HTTPBase(object): :return: dictionary """ if not isinstance(message, basestring): - request = "%s" % (message,) + message = "%s" % (message,) return http_form_post_message(message, destination, relay_state, typ) @@ -213,7 +244,7 @@ class HTTPBase(object): :return: dictionary """ if not isinstance(message, basestring): - request = "%s" % (message,) + message = "%s" % (message,) return http_redirect_message(message, destination, relay_state, typ) @@ -278,7 +309,7 @@ class HTTPBase(object): soap_message = _signed return {"url": destination, "method": "POST", - "data":soap_message, "headers":headers} + "data": soap_message, "headers": headers} def send_using_soap(self, request, destination, headers=None, sign=False): """ diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index c756078..509c603 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -202,6 +202,7 @@ class MetaData(object): res[srv["binding"]].append(srv) except KeyError: res[srv["binding"]] = [srv] + logger.debug("_service => %s" % res) return res def _ext_service(self, entity_id, typ, service, binding): From 114a015bfafee32283c27e9a7d44f854e06fe7af Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Tue, 19 Feb 2013 13:05:52 +0100 Subject: [PATCH 4/8] port not part of cookie domain. --- src/saml2/httpbase.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py index 11e369d..977b509 100644 --- a/src/saml2/httpbase.py +++ b/src/saml2/httpbase.py @@ -108,7 +108,7 @@ class HTTPBase(object): #if part.port: # _domain = "%s:%s" % (part.hostname, part.port) #else: - _domain = part.netloc + _domain = part.hostname cookie_dict = {} now = utc_now() @@ -135,7 +135,7 @@ class HTTPBase(object): logger.debug("Set Cookie '%s'" % kaka) part = urlparse.urlparse(request.url) - _domain = part.netloc + _domain = part.hostname for cookie_name, morsel in kaka.items(): std_attr = ATTRS.copy() From 1ade66f38cd694218fe3735e0a609e1c3a6334b5 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Wed, 20 Feb 2013 14:47:53 +0100 Subject: [PATCH 5/8] Moved log line for better info. --- src/saml2/httpbase.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py index 977b509..9410bfe 100644 --- a/src/saml2/httpbase.py +++ b/src/saml2/httpbase.py @@ -133,9 +133,9 @@ class HTTPBase(object): if not kaka: return - logger.debug("Set Cookie '%s'" % kaka) part = urlparse.urlparse(request.url) _domain = part.hostname + logger.debug("%s: '%s'" % (_domain, kaka)) for cookie_name, morsel in kaka.items(): std_attr = ATTRS.copy() From ed8b6953bb1421cca3b7637b2b625447b9f3038b Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Wed, 20 Feb 2013 14:48:12 +0100 Subject: [PATCH 6/8] editorial --- src/saml2/time_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/saml2/time_util.py b/src/saml2/time_util.py index 07a5075..57fdddf 100644 --- a/src/saml2/time_util.py +++ b/src/saml2/time_util.py @@ -149,16 +149,16 @@ def add_duration(tid, duration): if days < 1: pass elif days > maximum_day_in_month_for(year, month): - days = days - maximum_day_in_month_for(year, month) + days -= maximum_day_in_month_for(year, month) carry = 1 else: break temp = month + carry month = modulo(temp, 1, 13) - year = year + f_quotient(temp, 1, 13) + year += f_quotient(temp, 1, 13) - return time.localtime(time.mktime((year, month, days, hour, minutes, - secs, 0, 0, -1))) + return time.localtime(time.mktime((year, month, days, hour, minutes, + secs, 0, 0, -1))) else: pass @@ -232,7 +232,7 @@ def str_to_time(timestr, format=TIME_FORMAT): return 0 try: then = time.strptime(timestr, format) - except Exception: # assume it's a format problem + except ValueError, err: # assume it's a format problem try: elem = TIME_FORMAT_WITH_FRAGMENT.match(timestr) except Exception, exc: From cc71990164832ae37dd90f780124eb147997093b Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 21 Feb 2013 12:36:05 +0100 Subject: [PATCH 7/8] Added support for signing/verifying messages when using the HTTP-Redirect binding. --- src/saml2/entity.py | 21 ++-- src/saml2/pack.py | 131 +++++++++++++++++++-- src/saml2/s_utils.py | 93 ++++++++++----- src/saml2/sigver.py | 269 +++++++++++++++++++++++++------------------ 4 files changed, 350 insertions(+), 164 deletions(-) diff --git a/src/saml2/entity.py b/src/saml2/entity.py index de241c6..205f6d7 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -208,8 +208,8 @@ class Entity(HTTPBase): if not seid: seid = sid(self.seed) - return {"id":seid, "version":VERSION, - "issue_instant":instant(), "issuer":self._issuer()} + return {"id": seid, "version": VERSION, + "issue_instant": instant(), "issuer": self._issuer()} def response_args(self, message, bindings=None, descr_type=""): info = {"in_response_to": message.id} @@ -278,10 +278,9 @@ class Entity(HTTPBase): :param text: The SOAP message :return: A dictionary with two keys "body" and "header" """ - return class_instances_from_soap_enveloped_saml_thingies(text, - [paos, - ecp, - samlp]) + return class_instances_from_soap_enveloped_saml_thingies(text, [paos, + ecp, + samlp]) def unpack_soap_message(self, text): """ @@ -367,8 +366,8 @@ class Entity(HTTPBase): _issuer = self._issuer(issuer) response = response_factory(issuer=_issuer, - in_response_to = in_response_to, - status = status) + in_response_to=in_response_to, + status=status) if consumer_url: response.destination = consumer_url @@ -377,7 +376,7 @@ class Entity(HTTPBase): setattr(response, key, val) if sign: - self.sign(response,to_sign=to_sign) + self.sign(response, to_sign=to_sign) elif to_sign: return signed_instance_factory(response, self.sec, to_sign) else: @@ -689,8 +688,8 @@ class Entity(HTTPBase): if binding in [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST]: try: # expected return address - kwargs["return_addr"] = self.config.endpoint(service, - binding=binding)[0] + kwargs["return_addr"] = self.config.endpoint( + service, binding=binding)[0] except Exception: logger.info("Not supposed to handle this!") return None diff --git a/src/saml2/pack.py b/src/saml2/pack.py index a985e7a..4cf5996 100644 --- a/src/saml2/pack.py +++ b/src/saml2/pack.py @@ -23,12 +23,15 @@ Bindings normally consists of three parts: - how to package the information - which protocol to use """ +import hashlib import urlparse import saml2 import base64 import urllib -from saml2.s_utils import deflate_and_base64_encode +from saml2.s_utils import deflate_and_base64_encode, Unsupported import logging +import M2Crypto +from saml2.sigver import RSA_SHA1, rsa_load, x509_rsa_loads, pem_format logger = logging.getLogger(__name__) @@ -51,7 +54,9 @@ FORM_SPEC = """
""" -def http_form_post_message(message, location, relay_state="", typ="SAMLRequest"): + +def http_form_post_message(message, location, relay_state="", + typ="SAMLRequest"): """The HTTP POST binding defines a mechanism by which SAML protocol messages may be transmitted within the base64-encoded content of a HTML form control. @@ -93,7 +98,47 @@ def http_form_post_message(message, location, relay_state="", typ="SAMLRequest") # """ # return {"headers": [("Content-type", "text/xml")], "data": message} -def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"): + +class BadSignature(Exception): + """The signature is invalid.""" + pass + + +def sha1_digest(msg): + return hashlib.sha1(msg).digest() + + +class Signer(object): + """Abstract base class for signing algorithms.""" + def sign(self, msg, key): + """Sign ``msg`` with ``key`` and return the signature.""" + raise NotImplementedError + + def verify(self, msg, sig, key): + """Return True if ``sig`` is a valid signature for ``msg``.""" + raise NotImplementedError + + +class RSASigner(Signer): + def __init__(self, digest, algo): + self.digest = digest + self.algo = algo + + def sign(self, msg, key): + return key.sign(self.digest(msg), self.algo) + + def verify(self, msg, sig, key): + try: + return key.verify(self.digest(msg), sig, self.algo) + except M2Crypto.RSA.RSAError, e: + raise BadSignature(e) + + +REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"] +RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"] + +def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", + sigalg=None, key=None): """The HTTP Redirect binding defines a mechanism by which SAML protocol messages can be transmitted within URL parameters. Messages are encoded for use with this binding using a URL encoding @@ -104,13 +149,21 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"): :param message: The message :param location: Where the message should be posted to :param relay_state: for preserving and conveying state information + :param typ: What type of message it is SAMLRequest/SAMLResponse/SAMLart + :param sigalg: The signature algorithm to use. + :param key: Key to use for signing :return: A tuple containing header information and a HTML message. """ if not isinstance(message, basestring): message = "%s" % (message,) + _order = None if typ in ["SAMLRequest", "SAMLResponse"]: + if typ == "SAMLRequest": + _order = REQ_ORDER + else: + _order = RESP_ORDER args = {typ: deflate_and_base64_encode(message)} elif typ == "SAMLart": args = {typ: message} @@ -120,16 +173,67 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest"): if relay_state: args["RelayState"] = relay_state + if sigalg: + # sigalgs + # http://www.w3.org/2000/09/xmldsig#dsa-sha1 + # http://www.w3.org/2000/09/xmldsig#rsa-sha1 + + args["SigAlg"] = sigalg + + if sigalg == RSA_SHA1: + signer = RSASigner(sha1_digest, "sha1") + string = "&".join([urllib.urlencode({k: args[k]}) for k in _order]) + args["Signature"] = base64.b64encode(signer.sign(string, key)) + string = urllib.urlencode(args) + else: + raise Unsupported("Signing algorithm") + else: + string = urllib.urlencode(args) + glue_char = "&" if urlparse.urlparse(location).query else "?" - login_url = glue_char.join([location, urllib.urlencode(args)]) + login_url = glue_char.join([location, string]) headers = [('Location', login_url)] body = [] - return {"headers":headers, "data":body} + return {"headers": headers, "data": body} + + +def verify_redirect_signature(info, cert): + """ + + :param info: A dictionary as produced by parse_qs, means all values are + lists. + :param cert: A certificate to use when verifying the signature + :return: True, if signature verified + """ + + if info["SigAlg"][0] == RSA_SHA1: + if "SAMLRequest" in info: + _order = REQ_ORDER + elif "SAMLResponse" in info: + _order = RESP_ORDER + else: + raise Unsupported( + "Verifying signature on something that should not be signed") + signer = RSASigner(sha1_digest, "sha1") + args = info.copy() + del args["Signature"] # everything but the signature + string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order]) + _key = x509_rsa_loads(pem_format(cert)) + _sign = base64.b64decode(info["Signature"][0]) + try: + signer.verify(string, _sign, _key) + return True + except BadSignature: + return False + else: + raise Unsupported("Signature algorithm: %s" % info["SigAlg"]) + DUMMY_NAMESPACE = "http://example.org/" PREFIX = '' + def make_soap_enveloped_saml_thingy(thingy, header_parts=None): """ Returns a soap envelope containing a SAML request as a text string. @@ -170,21 +274,24 @@ def make_soap_enveloped_saml_thingy(thingy, header_parts=None): cut1 = _str[j:i + len(DUMMY_NAMESPACE) + 1] _str = _str.replace(cut1, "") first = _str.find("<%s:FuddleMuddle" % (cut1[6:9],)) - last = _str.find(">", first+14) - cut2 = _str[first:last+1] + last = _str.find(">", first + 14) + cut2 = _str[first:last + 1] return _str.replace(cut2, thingy) else: thingy.become_child_element_of(body) return ElementTree.tostring(envelope, encoding="UTF-8") + def http_soap_message(message): return {"headers": [("Content-type", "application/soap+xml")], "data": make_soap_enveloped_saml_thingy(message)} + def http_paos(message, extra=None): - return {"headers":[("Content-type", "application/soap+xml")], + return {"headers": [("Content-type", "application/soap+xml")], "data": make_soap_enveloped_saml_thingy(message, extra)} + def parse_soap_enveloped_saml(text, body_class, header_class=None): """Parses a SOAP enveloped SAML thing and returns header parts and body @@ -205,7 +312,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): body = saml2.create_class_from_element_tree(body_class, sub) except Exception: raise Exception( - "Wrong body type (%s) in SOAP envelope" % sub.tag) + "Wrong body type (%s) in SOAP envelope" % sub.tag) elif part.tag == '{%s}Header' % NAMESPACE: if not header_class: raise Exception("Header where I didn't expect one") @@ -226,13 +333,15 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): PACKING = { saml2.BINDING_HTTP_REDIRECT: http_redirect_message, saml2.BINDING_HTTP_POST: http_form_post_message, - } +} -def packager( identifier ): + +def packager(identifier): try: return PACKING[identifier] except KeyError: raise Exception("Unkown binding type: %s" % identifier) + def factory(binding, message, location, relay_state="", typ="SAMLRequest"): return PACKING[binding](message, location, relay_state, typ) \ No newline at end of file diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py index cb5e1d5..1b28dc1 100644 --- a/src/saml2/s_utils.py +++ b/src/saml2/s_utils.py @@ -9,7 +9,7 @@ import sys import hmac # from python 2.5 -if sys.version_info >= (2,5): +if sys.version_info >= (2, 5): import hashlib else: # before python 2.5 import sha @@ -27,36 +27,51 @@ import zlib logger = logging.getLogger(__name__) + class SamlException(Exception): pass + class RequestVersionTooLow(SamlException): pass + class RequestVersionTooHigh(SamlException): pass + class UnknownPrincipal(SamlException): pass -class UnsupportedBinding(SamlException): + +class Unsupported(SamlException): pass + +class UnsupportedBinding(Unsupported): + pass + + class VersionMismatch(Exception): pass + class Unknown(Exception): pass + class OtherError(Exception): pass + class MissingValue(Exception): pass + class PolicyError(Exception): pass + class BadRequest(Exception): pass @@ -73,28 +88,29 @@ EXCEPTION2STATUS = { Exception: samlp.STATUS_AUTHN_FAILED, } -GENERIC_DOMAINS = "aero", "asia", "biz", "cat", "com", "coop", \ - "edu", "gov", "info", "int", "jobs", "mil", "mobi", "museum", \ - "name", "net", "org", "pro", "tel", "travel" +GENERIC_DOMAINS = ["aero", "asia", "biz", "cat", "com", "coop", "edu", + "gov", "info", "int", "jobs", "mil", "mobi", "museum", + "name", "net", "org", "pro", "tel", "travel"] -def valid_email(emailaddress, domains = GENERIC_DOMAINS): + +def valid_email(emailaddress, domains=GENERIC_DOMAINS): """Checks for a syntactically valid email address.""" # Email address must be at least 6 characters in total. # Assuming noone may have addresses of the type a@com if len(emailaddress) < 6: - return False # Address too short. + return False # Address too short. # Split up email address into parts. try: localpart, domainname = emailaddress.rsplit('@', 1) host, toplevel = domainname.rsplit('.', 1) except ValueError: - return False # Address does not have enough parts. + return False # Address does not have enough parts. # Check for Country code or Generic Domain. if len(toplevel) != 2 and toplevel not in domains: - return False # Not a domain name. + return False # Not a domain name. for i in '-_.%+.': localpart = localpart.replace(i, "") @@ -102,27 +118,30 @@ def valid_email(emailaddress, domains = GENERIC_DOMAINS): host = host.replace(i, "") if localpart.isalnum() and host.isalnum(): - return True # Email address is fine. + return True # Email address is fine. else: - return False # Email address has funny characters. + return False # Email address has funny characters. -def decode_base64_and_inflate( string ): + +def decode_base64_and_inflate(string): """ base64 decodes and then inflates according to RFC1951 :param string: a deflated and encoded string :return: the string after decoding and inflating """ - return zlib.decompress( base64.b64decode( string ) , -15) + return zlib.decompress(base64.b64decode(string), -15) -def deflate_and_base64_encode( string_val ): + +def deflate_and_base64_encode(string_val): """ Deflates and the base64 encodes a string :param string_val: The string to deflate and encode :return: The deflated and encoded string """ - return base64.b64encode( zlib.compress( string_val )[2:-4] ) + return base64.b64encode(zlib.compress(string_val)[2:-4]) + def rndstr(size=16): """ @@ -134,9 +153,11 @@ def rndstr(size=16): _basech = string.ascii_letters + string.digits return "".join([random.choice(_basech) for _ in range(size)]) + def sid(seed=""): """The hash of the server time + seed makes an unique SID for each session. - 128-bits long so it fulfills the SAML2 requirements which states 128-160 bits + 128-bits long so it fulfills the SAML2 requirements which states + 128-160 bits :param seed: A seed string :return: The hex version of the digest, prefixed by 'id-' to make it @@ -146,7 +167,8 @@ def sid(seed=""): ident.update(repr(time.time())) if seed: ident.update(seed) - return "id-"+ident.hexdigest() + return "id-" + ident.hexdigest() + def parse_attribute_map(filenames): """ @@ -168,6 +190,7 @@ def parse_attribute_map(filenames): return forward, backward + def identity_attribute(form, attribute, forward_map=None): if form == "friendly": if attribute.friendly_name: @@ -182,6 +205,7 @@ def identity_attribute(form, attribute, forward_map=None): #---------------------------------------------------------------------------- + def error_status_factory(info): if isinstance(info, Exception): try: @@ -194,39 +218,38 @@ def error_status_factory(info): status_code=samlp.StatusCode( value=samlp.STATUS_RESPONDER, status_code=samlp.StatusCode( - value=exc_val) - ), - ) + value=exc_val))) else: (errcode, text) = info status = samlp.Status( status_message=samlp.StatusMessage(text=text), status_code=samlp.StatusCode( value=samlp.STATUS_RESPONDER, - status_code=samlp.StatusCode(value=errcode) - ), - ) + status_code=samlp.StatusCode(value=errcode))) return status + def success_status_factory(): return samlp.Status(status_code=samlp.StatusCode( - value=samlp.STATUS_SUCCESS)) + value=samlp.STATUS_SUCCESS)) + def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER): return samlp.Status( status_message=samlp.StatusMessage(text=message), - status_code=samlp.StatusCode( - value=fro, - status_code=samlp.StatusCode(value=code))) + status_code=samlp.StatusCode(value=fro, + status_code=samlp.StatusCode(value=code))) + def assertion_factory(**kwargs): assertion = saml.Assertion(version=VERSION, id=sid(), - issue_instant=instant()) + issue_instant=instant()) for key, val in kwargs.items(): setattr(assertion, key, val) return assertion + def _attrval(val, typ=""): if isinstance(val, list) or isinstance(val, set): attrval = [saml.AttributeValue(text=v) for v in val] @@ -246,6 +269,7 @@ def _attrval(val, typ=""): # xmlns:xs="http://www.w3.org/2001/XMLSchema" # xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + def do_ava(val, typ=""): if isinstance(val, basestring): ava = saml.AttributeValue() @@ -253,7 +277,7 @@ def do_ava(val, typ=""): attrval = [ava] elif isinstance(val, list): attrval = [do_ava(v)[0] for v in val] - elif val or val == False: + elif val or val is False: ava = saml.AttributeValue() ava.set_text(val) attrval = [ava] @@ -268,6 +292,7 @@ def do_ava(val, typ=""): return attrval + def do_attribute(val, typ, key): attr = saml.Attribute() attrval = do_ava(val, typ) @@ -276,7 +301,7 @@ def do_attribute(val, typ, key): if isinstance(key, basestring): attr.name = key - elif isinstance(key, tuple): # 3-tuple or 2-tuple + elif isinstance(key, tuple): # 3-tuple or 2-tuple try: (name, nformat, friendly) = key except ValueError: @@ -290,6 +315,7 @@ def do_attribute(val, typ, key): attr.friendly_name = friendly return attr + def do_attributes(identity): attrs = [] if not identity: @@ -308,6 +334,7 @@ def do_attributes(identity): attrs.append(attr) return attrs + def do_attribute_statement(identity): """ :param identity: A dictionary with fiendly names as keys @@ -315,12 +342,14 @@ def do_attribute_statement(identity): """ return saml.AttributeStatement(attribute=do_attributes(identity)) + def factory(klass, **kwargs): instance = klass() for key, val in kwargs.items(): setattr(instance, key, val) return instance + def signature(secret, parts): """Generates a signature. """ @@ -334,6 +363,7 @@ def signature(secret, parts): return csum.hexdigest() + def verify_signature(secret, parts): """ Checks that the signature is correct """ if signature(secret, parts[:-1]) == parts[-1]: @@ -344,9 +374,10 @@ def verify_signature(secret, parts): FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#" + def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion): """ - 'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value ) '#' + 'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value) '#' Allowed attributes: TS the login time stamp RP the relying party entityID diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index dd647cf..f4e58fe 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -25,7 +25,9 @@ import logging import random import os import sys +from time import mktime import M2Crypto +from M2Crypto.X509 import load_cert_string from saml2.samlp import Response import xmldsig as ds @@ -38,7 +40,7 @@ from saml2 import VERSION from saml2.s_utils import sid -from saml2.time_util import instant +from saml2.time_util import instant, utc_now, str_to_time from tempfile import NamedTemporaryFile from subprocess import Popen, PIPE @@ -47,6 +49,9 @@ logger = logging.getLogger(__name__) SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature") +RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1" + + def signed(item): if SIG in item.c_children.keys() and item.signature: return True @@ -62,6 +67,7 @@ def signed(item): return False + def get_xmlsec_binary(paths=None): """ Tries to find the xmlsec1 binary. @@ -75,7 +81,7 @@ def get_xmlsec_binary(paths=None): bin_name = "xmlsec1" elif os.name == "nt": bin_name = "xmlsec1.exe" - else: # Default !? + else: # Default !? bin_name = "xmlsec1" if paths: @@ -109,47 +115,36 @@ ENC_KEY_CLASS = "EncryptedKey" _TEST_ = True + class SignatureError(Exception): pass + class XmlsecError(Exception): pass + class MissingKey(Exception): pass + class DecryptError(Exception): pass # -------------------------------------------------------------------------- -#def make_signed_instance(klass, spec, seccont, base64encode=False): -# """ Will only return signed instance if the signature -# preamble is present -# -# :param klass: The type of class the instance should be -# :param spec: The specification of attributes and children of the class -# :param seccont: The security context (instance of SecurityContext) -# :param base64encode: Whether the attribute values should be base64 encoded -# :return: A signed (or possibly unsigned) instance of the class -# """ -# if "signature" in spec: -# signed_xml = seccont.sign_statement_using_xmlsec("%s" % instance, -# class_name(instance), instance.id) -# return create_class_from_xml_string(instance.__class__, signed_xml) -# else: -# return make_instance(klass, spec, base64encode) def xmlsec_version(execname): - com_list = [execname,"--version"] + com_list = [execname, "--version"] pof = Popen(com_list, stderr=PIPE, stdout=PIPE) try: return pof.stdout.read().split(" ")[1] except Exception: return "" - + + def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, - base64encode=False, elements_to_sign=None): + base64encode=False, elements_to_sign=None): """ Creates a class instance with a specified value, the specified class instance may be a value on a property in a defined class instance. @@ -169,14 +164,14 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, if isinstance(val, dict): cinst = _instance(klass, val, seccont, base64encode=base64encode, - elements_to_sign=elements_to_sign) + elements_to_sign=elements_to_sign) else: try: cinst = klass().set_text(val) except ValueError: if not part: cis = [_make_vals(sval, klass, seccont, klass_inst, prop, - True, base64encode, elements_to_sign) for sval in val] + True, base64encode, elements_to_sign) for sval in val] setattr(klass_inst, prop, cis) else: raise @@ -188,6 +183,7 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, cis = [cinst] setattr(klass_inst, prop, cis) + def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): instance = klass() @@ -208,30 +204,31 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): #print "## %s, %s" % (prop, klassdef) if prop in ava: #print "### %s" % ava[prop] - if isinstance(klassdef, list): # means there can be a list of values + if isinstance(klassdef, list): + # means there can be a list of values _make_vals(ava[prop], klassdef[0], seccont, instance, prop, - base64encode=base64encode, - elements_to_sign=elements_to_sign) + base64encode=base64encode, + elements_to_sign=elements_to_sign) else: cis = _make_vals(ava[prop], klassdef, seccont, instance, prop, - True, base64encode, elements_to_sign) + True, base64encode, elements_to_sign) setattr(instance, prop, cis) if "extension_elements" in ava: for item in ava["extension_elements"]: instance.extension_elements.append( - ExtensionElement(item["tag"]).loadd(item)) + ExtensionElement(item["tag"]).loadd(item)) if "extension_attributes" in ava: for key, val in ava["extension_attributes"].items(): instance.extension_attributes[key] = val - - + if "signature" in ava: elements_to_sign.append((class_name(instance), instance.id)) return instance + def signed_instance_factory(instance, seccont, elements_to_sign=None): """ @@ -243,8 +240,8 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): if elements_to_sign: signed_xml = "%s" % instance for (node_name, nodeid) in elements_to_sign: - signed_xml = seccont.sign_statement_using_xmlsec(signed_xml, - klass_namn=node_name, nodeid=nodeid) + signed_xml = seccont.sign_statement_using_xmlsec( + signed_xml, klass_namn=node_name, nodeid=nodeid) #print "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" #print "%s" % signed_xml @@ -255,6 +252,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): # -------------------------------------------------------------------------- + def create_id(): """ Create a string of 40 random characters from the set [a-p], can be used as a unique identifier of objects. @@ -266,6 +264,7 @@ def create_id(): ret += chr(random.randint(0, 15) + ord('a')) return ret + def make_temp(string, suffix="", decode=True): """ xmlsec needs files in some cases where only strings exist, hence the need for this function. It creates a temporary file with the @@ -288,9 +287,34 @@ def make_temp(string, suffix="", decode=True): ntf.seek(0) return ntf, ntf.name + def split_len(seq, length): - return [seq[i:i+length] for i in range(0, len(seq), length)] - + return [seq[i:i + length] for i in range(0, len(seq), length)] + +# -------------------------------------------------------------------------- + +M2_TIME_FORMAT = "%b %d %H:%M:%S %Y" + + +def to_time(_time): + assert _time.endswith(" GMT") + _time = _time[:-4] + return mktime(str_to_time(_time, M2_TIME_FORMAT)) + + +def active_cert(key): + cert_str = pem_format(key) + certificate = load_cert_string(cert_str) + not_before = to_time(str(certificate.get_not_before())) + not_after = to_time(str(certificate.get_not_after())) + try: + assert not_before < utc_now() + assert not_after > utc_now() + return True + except AssertionError: + return False + + def cert_from_key_info(key_info): """ Get all X509 certs from a KeyInfo instance. Care is taken to make sure that the certs are continues sequences of bytes. @@ -307,11 +331,15 @@ def cert_from_key_info(key_info): #print "X509Data",x509_data x509_certificate = x509_data.x509_certificate cert = x509_certificate.text.strip() - cert = "\n".join(split_len("".join([ - s.strip() for s in cert.split()]),64)) - res.append(cert) + cert = "\n".join(split_len("".join([s.strip() for s in + cert.split()]), 64)) + if active_cert(cert): + res.append(cert) + else: + logger.info("Inactive cert") return res + def cert_from_key_info_dict(key_info): """ Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure that the certs are continues sequences of bytes. @@ -330,11 +358,15 @@ def cert_from_key_info_dict(key_info): for x509_data in key_info["x509_data"]: x509_certificate = x509_data["x509_certificate"] cert = x509_certificate["text"].strip() - cert = "\n".join(split_len("".join([ - s.strip() for s in cert.split()]),64)) - res.append(cert) + cert = "\n".join(split_len("".join([s.strip() for s in + cert.split()]), 64)) + if active_cert(cert): + res.append(cert) + else: + logger.info("Inactive cert") return res + def cert_from_instance(instance): """ Find certificates that are part of an instance @@ -350,25 +382,30 @@ def cert_from_instance(instance): from M2Crypto.__m2crypto import bn_to_mpi from M2Crypto.__m2crypto import hex_to_bn + def intarr2long(arr): return long(''.join(["%02x" % byte for byte in arr]), 16) + def dehexlify(bi): s = hexlify(bi) - return [int(s[i]+s[i+1], 16) for i in range(0,len(s),2)] + return [int(s[i] + s[i + 1], 16) for i in range(0, len(s), 2)] + def long_to_mpi(num): """Converts a python integer or long to OpenSSL MPInt used by M2Crypto. Borrowed from Snowball.Shared.Crypto""" - h = hex(num)[2:] # strip leading 0x in string + h = hex(num)[2:] # strip leading 0x in string if len(h) % 2 == 1: - h = '0' + h # add leading 0 to get even number of hexdigits - return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum + h = '0' + h # add leading 0 to get even number of hexdigits + return bn_to_mpi(hex_to_bn(h)) # convert using OpenSSL BinNum + def base64_to_long(data): _d = base64.urlsafe_b64decode(data + '==') return intarr2long(dehexlify(_d)) + def key_from_key_value(key_info): res = [] for value in key_info.key_value: @@ -376,10 +413,11 @@ def key_from_key_value(key_info): e = base64_to_long(value.rsa_key_value.exponent) m = base64_to_long(value.rsa_key_value.modulus) key = M2Crypto.RSA.new_pub_key((long_to_mpi(e), - long_to_mpi(m))) + long_to_mpi(m))) res.append(key) return res + def key_from_key_value_dict(key_info): res = [] if not "key_value" in key_info: @@ -396,10 +434,28 @@ def key_from_key_value_dict(key_info): # ============================================================================= + +def rsa_load(filename): + """Read a PEM-encoded RSA key pair from a file.""" + return M2Crypto.RSA.load_key(filename, M2Crypto.util.no_passphrase_callback) + + +def rsa_loads(key): + """Read a PEM-encoded RSA key pair from a string.""" + return M2Crypto.RSA.load_key_string(key, + M2Crypto.util.no_passphrase_callback) + + +def x509_rsa_loads(string): + cert = M2Crypto.X509.load_cert_string(string) + return cert.get_pubkey().get_rsa() + + def pem_format(key): return "\n".join(["-----BEGIN CERTIFICATE-----", - key,"-----END CERTIFICATE-----"]) + key, "-----END CERTIFICATE-----"]) + def parse_xmlsec_output(output): """ Parse the output from xmlsec to try to find out if the command was successfull or not. @@ -416,12 +472,13 @@ def parse_xmlsec_output(output): __DEBUG = 0 -LOG_LINE = 60*"="+"\n%s\n"+60*"-"+"\n%s"+60*"=" -LOG_LINE_2 = 60*"="+"\n%s\n%s\n"+60*"-"+"\n%s"+60*"=" +LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" +LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" + def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem", - node_name=NODE_NAME, debug=False, node_id=None, - id_attr=""): + node_name=NODE_NAME, debug=False, node_id=None, + id_attr=""): """ Verifies the signature of a XML document. :param enctext: The signed XML document @@ -481,6 +538,7 @@ def verify_signature(enctext, xmlsec_binary, cert_file=None, cert_type="pem", # --------------------------------------------------------------------------- + def read_cert_from_file(cert_file, cert_type): """ Reads a certificate from a file. The assumption is that there is only one certificate in the file @@ -516,6 +574,7 @@ def read_cert_from_file(cert_file, cert_type): data = open(cert_file).read() return base64.b64encode(str(data)) + def security_context(conf, debug=None): """ Creates a security context based on the configuration @@ -535,14 +594,15 @@ def security_context(conf, debug=None): _only_md = False return SecurityContext(conf.xmlsec_binary, conf.key_file, - cert_file=conf.cert_file, metadata=metadata, - debug=debug, only_use_keys_in_metadata=_only_md) + cert_file=conf.cert_file, metadata=metadata, + debug=debug, only_use_keys_in_metadata=_only_md) + class SecurityContext(object): - def __init__(self, xmlsec_binary, key_file="", key_type= "pem", - cert_file="", cert_type="pem", metadata=None, - debug=False, template="", encrypt_key_type="des-192", - only_use_keys_in_metadata=False): + def __init__(self, xmlsec_binary, key_file="", key_type="pem", + cert_file="", cert_type="pem", metadata=None, + debug=False, template="", encrypt_key_type="des-192", + only_use_keys_in_metadata=False): self.xmlsec = xmlsec_binary @@ -592,12 +652,9 @@ class SecurityContext(object): _, fil = make_temp("%s" % text, decode=False) ntf = NamedTemporaryFile() - com_list = [self.xmlsec, "--encrypt", - "--pubkey-pem", recv_key, - "--session-key", key_type, - "--xml-data", fil, - "--output", ntf.name, - template] + com_list = [self.xmlsec, "--encrypt", "--pubkey-pem", recv_key, + "--session-key", key_type, "--xml-data", fil, + "--output", ntf.name, template] logger.debug("Encryption command: %s" % " ".join(com_list)) @@ -625,11 +682,9 @@ class SecurityContext(object): _, fil = make_temp("%s" % enctext, decode=False) ntf = NamedTemporaryFile() - com_list = [self.xmlsec, "--decrypt", - "--privkey-pem", self.key_file, - "--output", ntf.name, - "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, - fil] + com_list = [self.xmlsec, "--decrypt", "--privkey-pem", + self.key_file, "--output", ntf.name, + "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS, fil] logger.debug("Decrypt command: %s" % " ".join(com_list)) @@ -646,9 +701,8 @@ class SecurityContext(object): ntf.seek(0) return ntf.read() - def verify_signature(self, enctext, cert_file=None, cert_type="pem", - node_name=NODE_NAME, node_id=None, id_attr=""): + node_name=NODE_NAME, node_id=None, id_attr=""): """ Verifies the signature of a XML document. :param enctext: The XML document as a string @@ -773,22 +827,22 @@ class SecurityContext(object): must, origdoc) def correctly_signed_authn_query(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "authn_query", must, origdoc) def correctly_signed_logout_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "logout_request", must, origdoc) def correctly_signed_logout_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "logout_response", must, origdoc) def correctly_signed_attribute_query(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "attribute_query", must, origdoc) @@ -799,31 +853,31 @@ class SecurityContext(object): origdoc) def correctly_signed_authz_decision_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "authz_decision_response", must, origdoc) def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "name_id_mapping_request", must, origdoc) def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "name_id_mapping_response", must, origdoc) def correctly_signed_artifact_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "artifact_request", must, origdoc) def correctly_signed_artifact_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "artifact_response", must, origdoc) @@ -835,19 +889,19 @@ class SecurityContext(object): must, origdoc) def correctly_signed_manage_name_id_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "manage_name_id_response", must, origdoc) def correctly_signed_assertion_id_request(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "assertion_id_request", must, origdoc) def correctly_signed_assertion_id_response(self, decoded_xml, must=False, - origdoc=None): + origdoc=None): return self.correctly_signed_message(decoded_xml, "assertion", must, origdoc) @@ -882,18 +936,16 @@ class SecurityContext(object): try: self._check_signature(decoded_xml, assertion, - class_name(assertion), origdoc) + class_name(assertion), origdoc) except Exception, exc: logger.error("correctly_signed_response: %s" % exc) raise return response - #-------------------------------------------------------------------------- # SIGNATURE PART #-------------------------------------------------------------------------- - def sign_statement_using_xmlsec(self, statement, klass_namn, key=None, key_file=None, nodeid=None, id_attr=""): """Sign a SAML statement using xmlsec. @@ -917,7 +969,6 @@ class SecurityContext(object): _, fil = make_temp("%s" % statement, decode=False) - ntf = NamedTemporaryFile() com_list = [self.xmlsec, "--sign", @@ -975,10 +1026,8 @@ class SecurityContext(object): :return: The signed statement """ - return self.sign_statement_using_xmlsec(statement, - class_name(samlp.AttributeQuery()), - key, key_file, nodeid, - id_attr=id_attr) + return self.sign_statement_using_xmlsec(statement, class_name( + samlp.AttributeQuery()), key, key_file, nodeid, id_attr=id_attr) def multiple_signatures(self, statement, to_sign, key=None, key_file=None): """ @@ -991,15 +1040,15 @@ class SecurityContext(object): :param key_file: A file that contains the key to be used :return: A possibly multiple signed statement """ - for (item, id, id_attr) in to_sign: - if not id: + for (item, sid, id_attr) in to_sign: + if not sid: if not item.id: - id = item.id = sid() + sid = item.id = sid() else: - id = item.id + sid = item.id if not item.signature: - item.signature = pre_signature_part(id, self.cert_file) + item.signature = pre_signature_part(sid, self.cert_file) statement = self.sign_statement_using_xmlsec(statement, class_name(item), @@ -1024,32 +1073,30 @@ def pre_signature_part(ident, public_key=None, identifier=None): :return: A preset signature part """ - signature_method = ds.SignatureMethod(algorithm = ds.SIG_RSA_SHA1) + signature_method = ds.SignatureMethod(algorithm=ds.SIG_RSA_SHA1) canonicalization_method = ds.CanonicalizationMethod( - algorithm = ds.ALG_EXC_C14N) - trans0 = ds.Transform(algorithm = ds.TRANSFORM_ENVELOPED) - trans1 = ds.Transform(algorithm = ds.ALG_EXC_C14N) - transforms = ds.Transforms(transform = [trans0, trans1]) - digest_method = ds.DigestMethod(algorithm = ds.DIGEST_SHA1) + algorithm=ds.ALG_EXC_C14N) + trans0 = ds.Transform(algorithm=ds.TRANSFORM_ENVELOPED) + trans1 = ds.Transform(algorithm=ds.ALG_EXC_C14N) + transforms = ds.Transforms(transform=[trans0, trans1]) + digest_method = ds.DigestMethod(algorithm=ds.DIGEST_SHA1) - reference = ds.Reference(uri = "#%s" % ident, - digest_value = ds.DigestValue(), - transforms = transforms, - digest_method = digest_method) + reference = ds.Reference(uri="#%s" % ident, digest_value=ds.DigestValue(), + transforms=transforms, digest_method=digest_method) - signed_info = ds.SignedInfo(signature_method = signature_method, - canonicalization_method = canonicalization_method, - reference = reference) + signed_info = ds.SignedInfo(signature_method=signature_method, + canonicalization_method=canonicalization_method, + reference=reference) - signature = ds.Signature(signed_info=signed_info, - signature_value=ds.SignatureValue()) + signature = ds.Signature(signed_info=signed_info, + signature_value=ds.SignatureValue()) if identifier: signature.id = "Signature%d" % identifier if public_key: - x509_data = ds.X509Data(x509_certificate=[ds.X509Certificate( - text=public_key)]) + x509_data = ds.X509Data( + x509_certificate=[ds.X509Certificate(text=public_key)]) key_info = ds.KeyInfo(x509_data=x509_data) signature.key_info = key_info From 746a45b5fe445ab044909a7a1d9620ef16636912 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Thu, 21 Feb 2013 13:08:20 +0100 Subject: [PATCH 8/8] Added support for signing/verifying messages when using the HTTP-Redirect binding. --- src/saml2/pack.py | 93 +++---------------------------- src/saml2/sigver.py | 89 +++++++++++++++++++++++++++-- tests/test_70_redirect_signing.py | 44 +++++++++++++++ 3 files changed, 134 insertions(+), 92 deletions(-) create mode 100644 tests/test_70_redirect_signing.py diff --git a/src/saml2/pack.py b/src/saml2/pack.py index 4cf5996..eed060c 100644 --- a/src/saml2/pack.py +++ b/src/saml2/pack.py @@ -23,15 +23,18 @@ Bindings normally consists of three parts: - how to package the information - which protocol to use """ -import hashlib import urlparse import saml2 import base64 import urllib -from saml2.s_utils import deflate_and_base64_encode, Unsupported +from saml2.s_utils import deflate_and_base64_encode +from saml2.s_utils import Unsupported import logging -import M2Crypto -from saml2.sigver import RSA_SHA1, rsa_load, x509_rsa_loads, pem_format +from saml2.sigver import RSA_SHA1 +from saml2.sigver import REQ_ORDER +from saml2.sigver import RESP_ORDER +from saml2.sigver import RSASigner +from saml2.sigver import sha1_digest logger = logging.getLogger(__name__) @@ -86,56 +89,6 @@ def http_form_post_message(message, location, relay_state="", return {"headers": [("Content-type", "text/html")], "data": response} -##noinspection PyUnresolvedReferences -#def http_post_message(message, location, relay_state="", typ="SAMLRequest"): -# """ -# -# :param message: -# :param location: -# :param relay_state: -# :param typ: -# :return: -# """ -# return {"headers": [("Content-type", "text/xml")], "data": message} - - -class BadSignature(Exception): - """The signature is invalid.""" - pass - - -def sha1_digest(msg): - return hashlib.sha1(msg).digest() - - -class Signer(object): - """Abstract base class for signing algorithms.""" - def sign(self, msg, key): - """Sign ``msg`` with ``key`` and return the signature.""" - raise NotImplementedError - - def verify(self, msg, sig, key): - """Return True if ``sig`` is a valid signature for ``msg``.""" - raise NotImplementedError - - -class RSASigner(Signer): - def __init__(self, digest, algo): - self.digest = digest - self.algo = algo - - def sign(self, msg, key): - return key.sign(self.digest(msg), self.algo) - - def verify(self, msg, sig, key): - try: - return key.verify(self.digest(msg), sig, self.algo) - except M2Crypto.RSA.RSAError, e: - raise BadSignature(e) - - -REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"] -RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"] def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", sigalg=None, key=None): @@ -198,38 +151,6 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", return {"headers": headers, "data": body} -def verify_redirect_signature(info, cert): - """ - - :param info: A dictionary as produced by parse_qs, means all values are - lists. - :param cert: A certificate to use when verifying the signature - :return: True, if signature verified - """ - - if info["SigAlg"][0] == RSA_SHA1: - if "SAMLRequest" in info: - _order = REQ_ORDER - elif "SAMLResponse" in info: - _order = RESP_ORDER - else: - raise Unsupported( - "Verifying signature on something that should not be signed") - signer = RSASigner(sha1_digest, "sha1") - args = info.copy() - del args["Signature"] # everything but the signature - string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order]) - _key = x509_rsa_loads(pem_format(cert)) - _sign = base64.b64decode(info["Signature"][0]) - try: - signer.verify(string, _sign, _key) - return True - except BadSignature: - return False - else: - raise Unsupported("Signature algorithm: %s" % info["SigAlg"]) - - DUMMY_NAMESPACE = "http://example.org/" PREFIX = '' diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index f4e58fe..2422a37 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -21,11 +21,13 @@ Based on the use of xmlsec1 binaries and not the python xmlsec module. import base64 from binascii import hexlify +import hashlib import logging import random import os import sys from time import mktime +import urllib import M2Crypto from M2Crypto.X509 import load_cert_string from saml2.samlp import Response @@ -39,8 +41,11 @@ from saml2 import ExtensionElement from saml2 import VERSION from saml2.s_utils import sid +from saml2.s_utils import Unsupported -from saml2.time_util import instant, utc_now, str_to_time +from saml2.time_util import instant +from saml2.time_util import utc_now +from saml2.time_util import str_to_time from tempfile import NamedTemporaryFile from subprocess import Popen, PIPE @@ -315,7 +320,7 @@ def active_cert(key): return False -def cert_from_key_info(key_info): +def cert_from_key_info(key_info, ignore_age=False): """ Get all X509 certs from a KeyInfo instance. Care is taken to make sure that the certs are continues sequences of bytes. @@ -333,14 +338,14 @@ def cert_from_key_info(key_info): cert = x509_certificate.text.strip() cert = "\n".join(split_len("".join([s.strip() for s in cert.split()]), 64)) - if active_cert(cert): + if ignore_age or active_cert(cert): res.append(cert) else: logger.info("Inactive cert") return res -def cert_from_key_info_dict(key_info): +def cert_from_key_info_dict(key_info, ignore_age=False): """ Get all X509 certs from a KeyInfo dictionary. Care is taken to make sure that the certs are continues sequences of bytes. @@ -360,7 +365,7 @@ def cert_from_key_info_dict(key_info): cert = x509_certificate["text"].strip() cert = "\n".join(split_len("".join([s.strip() for s in cert.split()]), 64)) - if active_cert(cert): + if ignore_age or active_cert(cert): res.append(cert) else: logger.info("Inactive cert") @@ -375,7 +380,8 @@ def cert_from_instance(instance): """ if instance.signature: if instance.signature.key_info: - return cert_from_key_info(instance.signature.key_info) + return cert_from_key_info(instance.signature.key_info, + ignore_age=True) return [] # ============================================================================= @@ -472,6 +478,77 @@ def parse_xmlsec_output(output): __DEBUG = 0 + +class BadSignature(Exception): + """The signature is invalid.""" + pass + + +def sha1_digest(msg): + return hashlib.sha1(msg).digest() + + +class Signer(object): + """Abstract base class for signing algorithms.""" + def sign(self, msg, key): + """Sign ``msg`` with ``key`` and return the signature.""" + raise NotImplementedError + + def verify(self, msg, sig, key): + """Return True if ``sig`` is a valid signature for ``msg``.""" + raise NotImplementedError + + +class RSASigner(Signer): + def __init__(self, digest, algo): + self.digest = digest + self.algo = algo + + def sign(self, msg, key): + return key.sign(self.digest(msg), self.algo) + + def verify(self, msg, sig, key): + try: + return key.verify(self.digest(msg), sig, self.algo) + except M2Crypto.RSA.RSAError, e: + raise BadSignature(e) + + +REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"] +RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"] + + +def verify_redirect_signature(info, cert): + """ + + :param info: A dictionary as produced by parse_qs, means all values are + lists. + :param cert: A certificate to use when verifying the signature + :return: True, if signature verified + """ + + if info["SigAlg"][0] == RSA_SHA1: + if "SAMLRequest" in info: + _order = REQ_ORDER + elif "SAMLResponse" in info: + _order = RESP_ORDER + else: + raise Unsupported( + "Verifying signature on something that should not be signed") + signer = RSASigner(sha1_digest, "sha1") + args = info.copy() + del args["Signature"] # everything but the signature + string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order]) + _key = x509_rsa_loads(pem_format(cert)) + _sign = base64.b64decode(info["Signature"][0]) + try: + signer.verify(string, _sign, _key) + return True + except BadSignature: + return False + else: + raise Unsupported("Signature algorithm: %s" % info["SigAlg"]) + LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" diff --git a/tests/test_70_redirect_signing.py b/tests/test_70_redirect_signing.py new file mode 100644 index 0000000..6dd9fc6 --- /dev/null +++ b/tests/test_70_redirect_signing.py @@ -0,0 +1,44 @@ +from saml2.pack import http_redirect_message +from saml2.sigver import verify_redirect_signature +from saml2.sigver import RSA_SHA1 +from saml2.server import Server +from saml2 import BINDING_HTTP_REDIRECT +from saml2.client import Saml2Client +from saml2.config import SPConfig +from saml2.sigver import rsa_load +from urlparse import parse_qs + +__author__ = 'rolandh' + +idp = Server(config_file="idp_all_conf") + +conf = SPConfig() +conf.load_file("servera_conf") +sp = Saml2Client(conf) + +def test(): + srvs = sp.metadata.single_sign_on_service(idp.config.entityid, + BINDING_HTTP_REDIRECT) + + destination = srvs[0]["location"] + req = sp.create_authn_request(destination, id="id1") + + try: + key = sp.sec.key + except AttributeError: + key = rsa_load(sp.sec.key_file) + + info = http_redirect_message(req, destination, relay_state="RS", + typ="SAMLRequest", sigalg=RSA_SHA1, key=key) + + verified_ok = False + + for param, val in info["headers"]: + if param == "Location": + _dict = parse_qs(val.split("?")[1]) + _certs = idp.metadata.certs(sp.config.entityid, "any", "signing") + for cert in _certs: + if verify_redirect_signature(_dict, cert): + verified_ok = True + + assert verified_ok \ No newline at end of file