correctly restoring environ["wsgi.input"] after reading POST content

This commit is contained in:
Patrick Brosi
2014-09-29 17:40:58 +02:00
parent 6c1b963a64
commit f5012004d3

View File

@@ -1,6 +1,6 @@
# #
""" """
A plugin that allows you to use SAML2 SSO as authentication A plugin that allows you to use SAML2 SSO as authentication
and SAML2 attribute aggregations as metadata collector in your and SAML2 attribute aggregations as metadata collector in your
WSGI application. WSGI application.
@@ -49,20 +49,20 @@ PAOS_HEADER_INFO = 'ver="%s";"%s"' % (paos.NAMESPACE, ECP_SERVICE)
def construct_came_from(environ): def construct_came_from(environ):
""" The URL that the user used when the process where interupted """ The URL that the user used when the process where interupted
for single-sign-on processing. """ for single-sign-on processing. """
came_from = environ.get("PATH_INFO") came_from = environ.get("PATH_INFO")
qstr = environ.get("QUERY_STRING", "") qstr = environ.get("QUERY_STRING", "")
if qstr: if qstr:
came_from += '?' + qstr came_from += '?' + qstr
return came_from return came_from
def cgi_field_storage_to_dict(field_storage): def cgi_field_storage_to_dict(field_storage):
"""Get a plain dictionary, rather than the '.value' system used by the """Get a plain dictionary, rather than the '.value' system used by the
cgi module.""" cgi module."""
params = {} params = {}
for key in field_storage.keys(): for key in field_storage.keys():
try: try:
@@ -70,26 +70,9 @@ def cgi_field_storage_to_dict(field_storage):
except AttributeError: except AttributeError:
if isinstance(field_storage[key], basestring): if isinstance(field_storage[key], basestring):
params[key] = field_storage[key] params[key] = field_storage[key]
return params return params
def get_body(environ):
length = int(environ["CONTENT_LENGTH"])
try:
body = environ["wsgi.input"].read(length)
except Exception, excp:
logger.exception("Exception while reading post: %s" % (excp,))
raise
# restore what I might have upset
from StringIO import StringIO
environ['wsgi.input'] = StringIO(body)
environ['s2repoze.body'] = body
return body
def exception_trace(tag, exc, log): def exception_trace(tag, exc, log):
message = traceback.format_exception(*sys.exc_info()) message = traceback.format_exception(*sys.exc_info())
log.error("[%s] ExcList: %s" % (tag, "".join(message),)) log.error("[%s] ExcList: %s" % (tag, "".join(message),))
@@ -113,7 +96,7 @@ class ECP_response(object):
class SAML2Plugin(object): class SAML2Plugin(object):
implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider) implements(IChallenger, IIdentifier, IAuthenticator, IMetadataProvider)
def __init__(self, rememberer_name, config, saml_client, wayf, cache, def __init__(self, rememberer_name, config, saml_client, wayf, cache,
sid_store=None, discovery="", idp_query_param="", sid_store=None, discovery="", idp_query_param="",
sid_store_cert=None,): sid_store_cert=None,):
@@ -158,27 +141,24 @@ class SAML2Plugin(object):
def _get_post(self, environ): def _get_post(self, environ):
""" """
Get the posted information Get the posted information
:param environ: A dictionary with environment variables :param environ: A dictionary with environment variables
""" """
post_env = environ.copy() body= ''
post_env['QUERY_STRING'] = ''
_ = get_body(environ)
try: try:
post = cgi.FieldStorage( length= int(environ.get('CONTENT_LENGTH', '0'))
fp=environ['wsgi.input'], except ValueError:
environ=post_env, length= 0
keep_blank_values=True if length!=0:
) body = environ['wsgi.input'].read(length) # get the POST variables
except Exception, excp: environ['s2repoze.body'] = body # store the request body for later use by pysaml2
logger.debug("Exception (II): %s" % (excp,)) environ['wsgi.input'] = StringIO(body) # restore the request body as a stream so that everything seems untouched
raise
post = parse_qs(body) # parse the POST fields into a dict
logger.debug('identify post: %s' % (post,)) logger.debug('identify post: %s' % (post,))
return post return post
def _wayf_redirect(self, came_from): def _wayf_redirect(self, came_from):
@@ -190,8 +170,8 @@ class SAML2Plugin(object):
#noinspection PyUnusedLocal #noinspection PyUnusedLocal
def _pick_idp(self, environ, came_from): def _pick_idp(self, environ, came_from):
""" """
If more than one idp and if none is selected, I have to do wayf or If more than one idp and if none is selected, I have to do wayf or
disco disco
""" """
@@ -230,7 +210,7 @@ class SAML2Plugin(object):
detail='unknown ECP version') detail='unknown ECP version')
idps = self.metadata.with_descriptor("idpsso") idps = self.metadata.with_descriptor("idpsso")
logger.info("IdP URL: %s" % idps) logger.info("IdP URL: %s" % idps)
idp_entity_id = query = None idp_entity_id = query = None
@@ -290,7 +270,7 @@ class SAML2Plugin(object):
logger.info("Chosen IdP: '%s'" % idp_entity_id) logger.info("Chosen IdP: '%s'" % idp_entity_id)
return 0, idp_entity_id return 0, idp_entity_id
#### IChallenger #### #### IChallenger ####
#noinspection PyUnusedLocal #noinspection PyUnusedLocal
def challenge(self, environ, _status, _app_headers, _forget_headers): def challenge(self, environ, _status, _app_headers, _forget_headers):
@@ -320,7 +300,7 @@ class SAML2Plugin(object):
came_from = construct_came_from(environ) came_from = construct_came_from(environ)
environ["myapp.came_from"] = came_from environ["myapp.came_from"] = came_from
logger.debug("[sp.challenge] RelayState >> '%s'" % came_from) logger.debug("[sp.challenge] RelayState >> '%s'" % came_from)
# Am I part of a virtual organization or more than one ? # Am I part of a virtual organization or more than one ?
try: try:
vorg_name = environ["myapp.vo"] vorg_name = environ["myapp.vo"]
@@ -329,7 +309,7 @@ class SAML2Plugin(object):
vorg_name = _cli.vorg._name vorg_name = _cli.vorg._name
except AttributeError: except AttributeError:
vorg_name = "" vorg_name = ""
logger.info("[sp.challenge] VO: %s" % vorg_name) logger.info("[sp.challenge] VO: %s" % vorg_name)
# If more than one idp and if none is selected, I have to do wayf # If more than one idp and if none is selected, I have to do wayf
@@ -373,7 +353,7 @@ class SAML2Plugin(object):
req_id, msg_str = _cli.create_authn_request( req_id, msg_str = _cli.create_authn_request(
dest, vorg=vorg_name, sign=_cli.authn_requests_signed, dest, vorg=vorg_name, sign=_cli.authn_requests_signed,
message_id=_sid, extensions=extensions) message_id=_sid, extensions=extensions)
_sid = req_id _sid = req_id
else: else:
req_id, req = _cli.create_authn_request( req_id, req = _cli.create_authn_request(
dest, vorg=vorg_name, sign=False, extensions=extensions) dest, vorg=vorg_name, sign=False, extensions=extensions)
@@ -423,7 +403,7 @@ class SAML2Plugin(object):
logger.debug("Identity: %s" % identity) logger.debug("Identity: %s" % identity)
return identity return identity
def _eval_authn_response(self, environ, post, binding=BINDING_HTTP_POST): def _eval_authn_response(self, environ, post, binding=BINDING_HTTP_POST):
logger.info("Got AuthN response, checking..") logger.info("Got AuthN response, checking..")
logger.info("Outstanding: %s" % (self.outstanding_queries,)) logger.info("Outstanding: %s" % (self.outstanding_queries,))
@@ -432,18 +412,18 @@ class SAML2Plugin(object):
# Evaluate the response, returns a AuthnResponse instance # Evaluate the response, returns a AuthnResponse instance
try: try:
authresp = self.saml_client.parse_authn_request_response( authresp = self.saml_client.parse_authn_request_response(
post["SAMLResponse"], binding, self.outstanding_queries, post["SAMLResponse"][0], binding, self.outstanding_queries,
self.outstanding_certs) self.outstanding_certs)
except Exception, excp: except Exception, excp:
logger.exception("Exception: %s" % (excp,)) logger.exception("Exception: %s" % (excp,))
raise raise
session_info = authresp.session_info() session_info = authresp.session_info()
except TypeError, excp: except TypeError, excp:
logger.exception("Exception: %s" % (excp,)) logger.exception("Exception: %s" % (excp,))
return None return None
if session_info["came_from"]: if session_info["came_from"]:
logger.debug("came_from << %s" % session_info["came_from"]) logger.debug("came_from << %s" % session_info["came_from"])
try: try:
@@ -478,13 +458,13 @@ class SAML2Plugin(object):
"SAMLResponse" not in query and "SAMLRequest" not in query: "SAMLResponse" not in query and "SAMLRequest" not in query:
logger.debug('[identify] get or empty post') logger.debug('[identify] get or empty post')
return None return None
# if logger: # if logger:
# logger.info("ENVIRON: %s" % environ) # logger.info("ENVIRON: %s" % environ)
# logger.info("self: %s" % (self.__dict__,)) # logger.info("self: %s" % (self.__dict__,))
uri = environ.get('REQUEST_URI', construct_url(environ)) uri = environ.get('REQUEST_URI', construct_url(environ))
logger.debug('[sp.identify] uri: %s' % (uri,)) logger.debug('[sp.identify] uri: %s' % (uri,))
query = parse_dict_querystring(environ) query = parse_dict_querystring(environ)
@@ -495,15 +475,13 @@ class SAML2Plugin(object):
binding = BINDING_HTTP_REDIRECT binding = BINDING_HTTP_REDIRECT
else: else:
post = self._get_post(environ) post = self._get_post(environ)
if post.list is None:
post.list = []
binding = BINDING_HTTP_POST binding = BINDING_HTTP_POST
try: try:
logger.debug('[sp.identify] post keys: %s' % (post.keys(),)) logger.debug('[sp.identify] post keys: %s' % (post.keys(),))
except (TypeError, IndexError): except (TypeError, IndexError):
pass pass
try: try:
path_info = environ['PATH_INFO'] path_info = environ['PATH_INFO']
logout = False logout = False
@@ -514,7 +492,7 @@ class SAML2Plugin(object):
print("logout request received") print("logout request received")
try: try:
response = self.saml_client.handle_logout_request( response = self.saml_client.handle_logout_request(
post["SAMLRequest"], post["SAMLRequest"][0],
self.saml_client.users.subjects()[0], binding) self.saml_client.users.subjects()[0], binding)
environ['samlsp.pending'] = self._handle_logout(response) environ['samlsp.pending'] = self._handle_logout(response)
return {} return {}
@@ -536,7 +514,7 @@ class SAML2Plugin(object):
try: try:
if logout: if logout:
response = self.saml_client.parse_logout_request_response( response = self.saml_client.parse_logout_request_response(
post["SAMLResponse"], binding) post["SAMLResponse"][0], binding)
if response: if response:
action = self.saml_client.handle_logout_response( action = self.saml_client.handle_logout_response(
response) response)
@@ -572,8 +550,8 @@ class SAML2Plugin(object):
exception_trace("sp.identity", exc, logger) exception_trace("sp.identity", exc, logger)
environ["post.fieldstorage"] = post environ["post.fieldstorage"] = post
return {} return {}
if session_info: if session_info:
environ["s2repoze.sessioninfo"] = session_info environ["s2repoze.sessioninfo"] = session_info
return self._construct_identity(session_info) return self._construct_identity(session_info)
else: else:
@@ -596,12 +574,12 @@ class SAML2Plugin(object):
logger.debug("Issuers: %s" % _cli.users.sources(name_id)) logger.debug("Issuers: %s" % _cli.users.sources(name_id))
except KeyError: except KeyError:
pass pass
if "user" not in identity: if "user" not in identity:
identity["user"] = {} identity["user"] = {}
try: try:
(ava, _) = _cli.users.get_identity(name_id) (ava, _) = _cli.users.get_identity(name_id)
#now = time.gmtime() #now = time.gmtime()
logger.debug("[add_metadata] adds: %s" % ava) logger.debug("[add_metadata] adds: %s" % ava)
identity["user"].update(ava) identity["user"].update(ava)
except KeyError: except KeyError:
@@ -625,7 +603,7 @@ class SAML2Plugin(object):
if not identity["user"]: if not identity["user"]:
# remove cookie and demand re-authentication # remove cookie and demand re-authentication
pass pass
# used 2 times : one to get the ticket, the other to validate it # used 2 times : one to get the ticket, the other to validate it
@staticmethod @staticmethod
def _service_url(environ, qstr=None): def _service_url(environ, qstr=None):
@@ -635,7 +613,7 @@ class SAML2Plugin(object):
url = construct_url(environ) url = construct_url(environ)
return url return url
#### IAuthenticatorPlugin #### #### IAuthenticatorPlugin ####
#noinspection PyUnusedLocal #noinspection PyUnusedLocal
def authenticate(self, environ, identity=None): def authenticate(self, environ, identity=None):
if identity: if identity:
@@ -672,7 +650,7 @@ def make_plugin(remember_name=None, # plugin for remember
discovery="", discovery="",
idp_query_param="" idp_query_param=""
): ):
if saml_conf is "": if saml_conf is "":
raise ValueError( raise ValueError(
'must include saml_conf in configuration') 'must include saml_conf in configuration')