Updated to comply with changes in pysaml2.

This commit is contained in:
Roland Hedberg
2012-12-21 15:32:10 +01:00
parent 96abdd254b
commit a4761b3274
10 changed files with 1438 additions and 664 deletions

10
script/saml2c.py Executable file
View File

@@ -0,0 +1,10 @@
#!/usr/bin/env python
__author__ = 'rohe0002'
from idp_test import saml2int
from idp_test import SAML2client
cli = SAML2client(saml2int)
cli.run()

View File

@@ -20,7 +20,7 @@ from setuptools import setup
__author__ = 'rohe0002'
setup(
name="oic",
name="saml2test",
version="0.3.0",
description="SAML2 test tool",
author = "Roland Hedberg",
@@ -37,4 +37,5 @@ setup(
"beautifulsoup4"],
zip_safe=False,
scripts=["script/saml2c.py"]
)

View File

@@ -1,11 +1,39 @@
from importlib import import_module
import json
import argparse
import sys
import time
import logging
from saml2.config import SPConfig
from idp_test.base import FatalError
from idp_test.base import do_sequence
from idp_test.httpreq import HTTPC
#from saml2.config import Config
from saml2.mdstore import MetadataStore, MetaData
# Schemas supported
from saml2 import md
from saml2 import saml
from saml2.extension import mdui
from saml2.extension import idpdisc
from saml2.extension import dri
from saml2.extension import mdattr
from saml2.extension import ui
from saml2.metadata import entity_descriptor
import xmldsig
import xmlenc
SCHEMA = [ dri, idpdisc, md, mdattr, mdui, saml, ui, xmldsig, xmlenc]
__author__ = 'rolandh'
import traceback
logger = logging.getLogger("")
def exception_trace(tag, exc, log=None):
message = traceback.format_exception(*sys.exc_info())
if log:
@@ -15,15 +43,58 @@ def exception_trace(tag, exc, log=None):
print >> sys.stderr, "[%s] ExcList: %s" % (tag, "".join(message),)
print >> sys.stderr, "[%s] Exception: %s" % (tag, exc)
class SAML2(object):
client_args = ["client_id", "redirect_uris", "password"]
class Trace(object):
def __init__(self):
self.trace = []
self.start = time.time()
def __init__(self, operations_mod, client_class, msgfactory):
self.operations_mod = operations_mod
self.client_class = client_class
self.client = None
#self.trace = Trace()
self.msgfactory = msgfactory
def request(self, msg):
delta = time.time() - self.start
self.trace.append("%f --> %s" % (delta, msg))
def reply(self, msg):
delta = time.time() - self.start
self.trace.append("%f <-- %s" % (delta, msg))
def info(self, msg):
delta = time.time() - self.start
self.trace.append("%f %s" % (delta, msg))
def error(self, msg):
delta = time.time() - self.start
self.trace.append("%f [ERROR] %s" % (delta, msg))
def warning(self, msg):
delta = time.time() - self.start
self.trace.append("%f [WARNING] %s" % (delta, msg))
def __str__(self):
try:
return "\n".join([t.encode("utf-8") for t in self.trace])
except UnicodeDecodeError:
arr = []
for t in self.trace:
try:
arr.append(t.encode("utf-8"))
except UnicodeDecodeError:
arr.append(t)
return "\n".join(arr)
def clear(self):
self.trace = []
def __getitem__(self, item):
return self.trace[item]
def next(self):
for line in self.trace:
yield line
class SAML2client(object):
def __init__(self, operations):
self.trace = Trace()
self.operations = operations
self._parser = argparse.ArgumentParser()
self._parser.add_argument('-d', dest='debug', action='store_true',
@@ -31,35 +102,21 @@ class SAML2(object):
self._parser.add_argument('-v', dest='verbose', action='store_true',
help="Print runtime information")
self._parser.add_argument('-C', dest="ca_certs",
help="CA certs to use to verify HTTPS server certificates, if HTTPS is used and no server CA certs are defined then no cert verification is done")
help="CA certs to use to verify HTTPS server certificates, if HTTPS is used and no server CA certs are defined then no cert verification will be done")
self._parser.add_argument('-J', dest="json_config_file",
help="Script configuration")
self._parser.add_argument('-S', dest="sp_id", help="SP id")
self._parser.add_argument("-s", dest="list_sp_id", action="store_true",
help="List all the SP variants as a JSON object")
self._parser.add_argument('-m', dest="metadata", action='store_true',
help="Return the SP metadata")
self._parser.add_argument("-l", dest="list", action="store_true",
help="List all the test flows as a JSON object")
self._parser.add_argument("-H", dest="host", default="example.com",
help="Which host the script is running on, used to construct the key export URL")
self._parser.add_argument("flow", nargs="?", help="Which test flow to run")
self._parser.add_argument("oper", nargs="?", help="Which test to run")
self.args = None
self.pinfo = None
self.sequences = []
self.function_args = {}
self.signing_key = None
self.encryption_key = None
self.test_log = []
self.environ = {}
self._pop = None
def parse_args(self):
self.json_config= self.json_config_file()
try:
self.features = self.json_config["features"]
except KeyError:
self.features = {}
self.pinfo = self.provider_info()
self.client_conf(self.client_args)
self.interactions = None
self.entity_id = None
self.sp_config = None
def json_config_file(self):
if self.args.json_config_file == "-":
@@ -67,6 +124,34 @@ class SAML2(object):
else:
return json.loads(open(self.args.json_config_file).read())
def sp_configure(self, metadata_construction=False):
sys.path.insert(0, ".")
mod = import_module("config_file")
if self.args.sp_id is None:
if len(mod.CONFIG) == 1:
self.args.sp_id = mod.CONFIG.keys()[0]
else:
raise Exception("SP id undefined")
self.sp_config = SPConfig().load(mod.CONFIG[self.args.sp_id],
metadata_construction)
def setup(self):
self.json_config= self.json_config_file()
_jc = self.json_config
self.interactions = _jc["interaction"]
self.entity_id = _jc["entity_id"]
self.sp_configure()
metadata = MetadataStore(SCHEMA, self.sp_config.attribute_converters,
self.sp_config.xmlsec_binary)
metadata[0] = MetaData(SCHEMA, self.sp_config.attribute_converters,
_jc["metadata"])
self.sp_config.metadata = metadata
def test_summation(self, id):
status = 0
for item in self.test_log:
@@ -91,75 +176,45 @@ class SAML2(object):
def run(self):
self.args = self._parser.parse_args()
if self.args.list:
return self.operations()
if self.args.metadata:
return self.make_meta()
elif self.args.list_sp_id:
return self.list_conf_id()
elif self.args.list:
return self.list_operations()
else:
if not self.args.flow:
raise Exception("Missing flow specification")
self.args.flow = self.args.flow.strip("'")
self.args.flow = self.args.flow.strip('"')
if not self.args.oper:
raise Exception("Missing test case specification")
self.args.oper = self.args.oper.strip("'")
self.args.oper = self.args.oper.strip('"')
flow_spec = self.operations_mod.FLOWS[self.args.flow]
self.setup()
try:
try:
block = flow_spec["block"]
oper = self.operations.OPERATIONS[self.args.oper]
except KeyError:
block = {}
self.parse_args()
_seq = self.make_sequence()
interact = self.get_interactions()
try:
self.do_features(interact, _seq, block)
except Exception,exc:
exception_trace("do_features", exc)
_output = {"status": 4,
"tests": [{"status": 4,
"message":"Couldn't run testflow: %s" % exc,
"id": "verify_features",
"name": "Make sure you don't do things you shouldn't"}]}
#print >> sys.stdout, json.dumps(_output)
print >> sys.stderr, "Undefined testcase"
return
tests = self.get_test()
self.client.state = "STATE0"
testres, trace = do_sequence(self.sp_config, oper, HTTPC(),
self.trace, self.interactions,
entity_id=self.json_config["entity_id"])
self.test_log = testres
sum = self.test_summation(self.args.oper)
print >>sys.stdout, json.dumps(sum)
if sum["status"] > 1 or self.args.debug:
print >> sys.stderr, trace
except FatalError:
pass
except Exception, err:
print >> sys.stderr, self.trace
print err
exception_trace("RUN", err)
self.environ.update({"provider_info": self.pinfo,
"client": self.client})
try:
except_exception = flow_spec["except_exception"]
except KeyError:
except_exception = False
try:
if self.args.verbose:
print >> sys.stderr, "Set up done, running sequence"
testres, trace = run_sequence(self.client, _seq, self.trace,
interact, self.msgfactory,
self.environ, tests,
self.json_config["features"],
self.args.verbose, self.cconf,
except_exception)
self.test_log.extend(testres)
sum = self.test_summation(self.args.flow)
print >>sys.stdout, json.dumps(sum)
if sum["status"] > 1 or self.args.debug:
print >>sys.stderr, trace
except Exception, err:
#print >> sys.stderr, self.trace
print err
exception_trace("RUN", err)
#if self._pop is not None:
# self._pop.terminate()
if "keyprovider" in self.environ and self.environ["keyprovider"]:
# os.kill(self.environ["keyprovider"].pid, signal.SIGTERM)
self.environ["keyprovider"].terminate()
def operations(self):
def list_operations(self):
lista = []
for key,val in self.operations_mod.FLOWS.items():
for key,val in self.operations.OPERATIONS.items():
item = {"id": key,
"name": val["name"],}
try:
@@ -178,112 +233,17 @@ class SAML2(object):
pass
lista.append(item)
print json.dumps(lista)
def provider_info(self):
# Should provide a Metadata class
res = {}
_jc = self.json_config["provider"]
def _get_operation(self, operation):
return self.operations.OPERATIONS[operation]
# Backward compatible
if "endpoints" in _jc:
try:
for endp, url in _jc["endpoints"].items():
res[endp] = url
except KeyError:
pass
for key in ProviderConfigurationResponse.c_param.keys():
try:
res[key] = _jc[key]
except KeyError:
pass
return res
def do_features(self, *args):
pass
def export(self):
pass
def client_conf(self, cprop):
if self.args.ca_certs:
self.client = self.client_class(ca_certs=self.args.ca_certs)
else:
try:
self.client = self.client_class(
ca_certs=self.json_config["ca_certs"])
except (KeyError, TypeError):
self.client = self.client_class()
#self.client.http_request = self.client.http.crequest
# set the endpoints in the Client from the provider information
# If they are statically configured, if dynamic it happens elsewhere
for key, val in self.pinfo.items():
if key.endswith("_endpoint"):
setattr(self.client, key, val)
# Client configuration
self.cconf = self.json_config["client"]
# replace pattern with real value
_h = self.args.host
self.cconf["redirect_uris"] = [p % _h for p in self.cconf["redirect_uris"]]
try:
self.client.client_prefs = self.cconf["preferences"]
except KeyError:
pass
# set necessary information in the Client
for prop in cprop:
try:
setattr(self.client, prop, self.cconf[prop])
except KeyError:
pass
def make_sequence(self):
# Whatever is specified on the command line takes precedences
if self.args.flow:
sequence = flow2sequence(self.operations_mod, self.args.flow)
elif self.json_config and "flow" in self.json_config:
sequence = flow2sequence(self.operations_mod,
self.json_config["flow"])
else:
sequence = None
return sequence
def get_interactions(self):
interactions = []
if self.json_config:
try:
interactions = self.json_config["interaction"]
except KeyError:
pass
if self.args.interactions:
_int = self.args.interactions.replace("\'", '"')
if interactions:
interactions.update(json.loads(_int))
else:
interactions = json.loads(_int)
return interactions
def get_test(self):
if self.args.flow:
flow = self.operations_mod.FLOWS[self.args.flow]
elif self.json_config and "flow" in self.json_config:
flow = self.operations_mod.FLOWS[self.json_config["flow"]]
else:
flow = None
try:
return flow["tests"]
except KeyError:
return []
def make_meta(self):
self.sp_configure(True)
print entity_descriptor(self.sp_config)
def list_conf_id(self):
sys.path.insert(0, ".")
mod = import_module("config_file")
_res = dict([(key, cnf["description"]) for key, cnf in mod.CONFIG.items()])
print json.dumps(_res)

View File

@@ -1,392 +1,281 @@
#!/usr/bin/env python
from check import ExpectedError
from check import factory
import base64
import inspect
from saml2 import BINDING_HTTP_REDIRECT
from saml2 import BINDING_HTTP_POST
from saml2 import BINDING_SOAP
from saml2.client import Saml2Client
#from idp_test.check import ExpectedError
from saml2.mdstore import REQ2SRV
from saml2.pack import http_redirect_message
from saml2.s_utils import rndstr
from idp_test.check import factory
from idp_test.check import STATUSCODE
from idp_test.interaction import Operation
from idp_test.interaction import pick_interaction
__author__ = 'rohe0002'
import time
import cookielib
from bs4 import BeautifulSoup
class FatalError(Exception):
pass
class Trace(object):
def __init__(self):
self.trace = []
self.start = time.time()
def form_post(request, relay_state):
return "SAMLRequest=%s&RelayState=%s" % (base64.b64encode("%s" % request),
relay_state)
def request(self, msg):
delta = time.time() - self.start
self.trace.append("%f --> %s" % (delta, msg))
def reply(self, msg):
delta = time.time() - self.start
self.trace.append("%f <-- %s" % (delta, msg))
def info(self, msg):
delta = time.time() - self.start
self.trace.append("%f %s" % (delta, msg))
def error(self, msg):
delta = time.time() - self.start
self.trace.append("%f [ERROR] %s" % (delta, msg))
def warning(self, msg):
delta = time.time() - self.start
self.trace.append("%f [WARNING] %s" % (delta, msg))
def __str__(self):
return "\n". join([t.encode("utf-8") for t in self.trace])
def clear(self):
self.trace = []
def __getitem__(self, item):
return self.trace[item]
def next(self):
for line in self.trace:
yield line
def flow2sequence(operations, item):
flow = operations.FLOWS[item]
return [operations.PHASES[phase] for phase in flow["sequence"]]
def endpoint(client, base):
for _endp in client._endpoints:
if getattr(client, _endp) == base:
return True
return False
def check_severity(stat):
def check_severity(stat, trace):
if stat["status"] >= 4:
trace.error("WHERE: %s" % stat["id"])
trace.error("STATUS:%s" % STATUSCODE[stat["status"]])
try:
trace.error("HTTP STATUS: %s" % stat["http_status"])
except KeyError:
pass
try:
trace.error("INFO: %s" % stat["message"])
except KeyError:
pass
raise FatalError
def pick_interaction(interactions, _base="", content="", req=None):
unic = content
if content:
_bs = BeautifulSoup(content)
def intermit(client, response, httpc, environ, trace, cjar, interaction,
test_output, features=None):
if response.status_code >= 400:
done = True
else:
_bs = None
done = False
for interaction in interactions:
_match = 0
for attr, val in interaction["matches"].items():
if attr == "url":
if val == _base:
_match += 1
elif attr == "title":
if _bs is None:
break
if _bs.title is None:
break
if val in _bs.title.contents:
_match += 1
elif attr == "content":
if unic and val in unic:
_match += 1
elif attr == "class":
if req and val == req:
_match += 1
url = response.url
content = response.text
if _match == len(interaction["matches"]):
return interaction
while not done:
while response.status_code in [302, 301, 303]:
url = response.headers["location"]
raise KeyError("No interaction matched")
trace.reply("REDIRECT TO: %s" % url)
# If back to me
for_me = False
acs = client.config.getattr("endpoints",
"sp")["assertion_consumer_service"]
for redirect_uri in acs:
if url.startswith(redirect_uri):
# Back at the RP
environ["client"].cookiejar = cjar["rp"]
for_me=True
ORDER = ["url", "response", "content"]
if for_me:
done = True
break
else:
try:
response = httpc.request(url, "GET", trace=trace)
except Exception, err:
raise FatalError("%s" % err)
def run_sequence(client, sequence, trace, interaction, msgfactory,
environ=None, tests=None, features=None, verbose=False,
cconf=None, except_exception=None):
item = []
response = None
content = None
url = ""
content = response.text
trace.reply("CONTENT: %s" % content)
environ.update({"url": url, "response": response,
"content":content})
check = factory("check-http-response")()
stat = check(environ, test_output)
check_severity(stat, trace)
if done:
break
_base = url.split("?")[0]
try:
_spec = pick_interaction(interaction, _base, content)
except KeyError:
chk = factory("interaction-needed")()
chk(environ, test_output)
raise FatalError()
if len(_spec) > 2:
trace.info(">> %s <<" % _spec["page-type"])
if _spec["page-type"] == "login":
environ["login"] = content
_op = Operation(_spec["control"])
try:
response = _op(httpc, environ, trace, url, response, content,
features)
if isinstance(response, dict):
return response
content = response.text
environ.update({"url": url, "response": response,
"content":content})
check = factory("check-http-response")()
stat = check(environ, test_output)
check_severity(stat, trace)
except FatalError:
raise
except Exception, err:
environ["exception"] = err
chk = factory("exception")()
chk(environ, test_output)
return response
def do_sequence(config, oper, httpc, trace, interaction, entity_id,
features=None):
"""
:param config: SP configuration
:param oper: A dictionary describing the operations to perform
:param httpc: A HTTP Client instance
:param trace: A Trace instance that keep all the trace information
:param interaction: A list of interaction definitions
:param entity_id: The entity_id of the IdP
:param features: ?
:returns: A 2-tuple (testoutput, tracelog)
"""
client = Saml2Client(config)
test_output = []
_keystore = client.keystore
features = features or {}
if client.metadata.entities_descr["-"]:
environ = {"metadata": client.metadata.entities_descr["-"]}
else:
environ = {"metadata": client.metadata.entity_descr["-"]}
cjar = {"browser": cookielib.CookieJar(),
"rp": cookielib.CookieJar(),
"service": cookielib.CookieJar()}
environ["sequence"] = sequence
environ["cis"] = []
environ["trace"] = trace
environ["responses"] = []
environ["FatalError"] = False
for op in oper["sequence"]:
output = do_query(client, op(), httpc, trace, interaction, entity_id,
environ, cjar, features)
test_output.extend(output)
if environ["FatalError"]:
break
return test_output, "%s" % trace
def do_query(client, oper, httpc, trace, interaction, entity_id, environ, cjar,
features=None):
"""
:param client: A SAML2 client
:param oper: A Request class instance
:param httpc: A HTTP Client instance
:param trace: A Trace instance that keep all the trace information
:param interaction: A list of interaction definitions
:param entity_id: The entity_id of the IdP
:param environ: Local environment
:param features: ?
:returns: A 2-tuple (testoutput, tracelog)
"""
oper.setup(environ)
query = oper.request
args = oper.args
args["entity_id"] = entity_id
test_output = []
try:
for creq in sequence:
req = creq()
cfunc = getattr(client, "create_%s" % req.request)
if trace:
trace.info(70*"=")
for test in oper.tests["pre"]:
chk = test()
stat = chk(environ, test_output)
try:
_pretests = req.tests["pre"]
for test in _pretests:
chk = test()
stat = chk(environ, test_output)
check_severity(stat)
except KeyError:
pass
check_severity(stat, trace)
except FatalError:
environ["FatalError"] = True
raise
except KeyError:
pass
httpc.cookiejar = cjar["browser"]
locations = getattr(client.metadata, REQ2SRV[query])(args["entity_id"],
args["binding"])
relay_state = rndstr()
_response_func = getattr(client, "%s_response" % query)
response_args = {}
qargs = args.copy()
qfunc = getattr(client, "create_%s" % query)
# remove args the create function can't handle
fargs = inspect.getargspec(qfunc).args
for arg in qargs.keys():
if arg not in fargs:
del qargs[arg]
resp = None
for loc in locations:
qargs["destination"] = loc
req = qfunc(**qargs)
environ["request"] = req
_req_str = "%s" % req
trace.info("SAML Request: %s" % _req_str)
# depending on binding send the query
if args["binding"] is BINDING_HTTP_REDIRECT:
(head, _body) = http_redirect_message(_req_str, loc, relay_state)
res = httpc.request(head[0][1], "GET")
response_args["outstanding"] = {req.id: "/"}
# head should contain a redirect
# deal with redirect, should in the end give me a response
try:
response = cfunc(**req.args)
response = intermit(client, res, httpc, environ, trace, cjar,
interaction, test_output, features)
except FatalError:
environ["FatalError"] = True
response = None
if isinstance(response, dict):
assert relay_state == response["RelayState"]
elif args["binding"] is BINDING_HTTP_POST:
body = form_post(_req_str, relay_state)
res = httpc.request(loc, "POST", data=body)
response_args["outstanding"] = {req.id: "/"}
# head should contain a redirect
# deal with redirect, should in the end give me a response
try:
response = intermit(client, res, httpc, environ, trace, cjar,
interaction, test_output, features)
except FatalError:
environ["FatalError"] = True
response = None
elif args["binding"] is BINDING_SOAP:
response = client.send_using_soap(_req_str, loc,
client.config.key_file,
client.config.cert_file,
ca_certs=client.config.ca_certs)
else:
response = None
if response:
try:
_resp = _response_func(response, **response_args)
environ["response"] = _resp
trace.info("SAML Response: %s" % _resp)
try:
for test in req.tests["post"]:
if isinstance(test, tuple):
test, kwargs = test
else:
kwargs = {}
chk = test(**kwargs)
for test in oper.tests["post"]:
chk = test()
stat = chk(environ, test_output)
check_severity(stat)
if isinstance(chk, ExpectedError):
item.append(stat["temp"])
del stat["temp"]
url = None
break
check_severity(stat, trace)
except KeyError:
pass
except FatalError:
raise
except FatalError:
environ["FatalError"] = True
break
except Exception, err:
environ["exception"] = err
chk = factory("exception")()
chk(environ, test_output)
raise FatalError()
if not response:
continue
if response.status_code >= 400:
done = True
elif url:
done = False
else:
done = True
while not done:
while response.status_code in [302, 301, 303]:
url = response.headers["location"]
trace.reply("REDIRECT TO: %s" % url)
# If back to me
for_me = False
for redirect_uri in client.redirect_uris:
if url.startswith(redirect_uri):
# Back at the RP
environ["client"].cookiejar = cjar["rp"]
for_me=True
if for_me:
done = True
break
else:
try:
part = do_request(client, url, "GET", trace=trace)
except Exception, err:
raise FatalError("%s" % err)
environ.update(dict(zip(ORDER, part)))
(url, response, content) = part
check = factory("check-http-response")()
stat = check(environ, test_output)
check_severity(stat)
if done:
break
_base = url.split("?")[0]
try:
_spec = pick_interaction(interaction, _base, content)
except KeyError:
if creq.method == "POST":
break
elif not req.request in ["AuthorizationRequest",
"OpenIDRequest"]:
break
else:
try:
_check = getattr(req, "interaction_check")
except AttributeError:
_check = None
if _check:
chk = factory("interaction-check")()
chk(environ, test_output)
raise FatalError()
else:
chk = factory("interaction-needed")()
chk(environ, test_output)
raise FatalError()
if len(_spec) > 2:
trace.info(">> %s <<" % _spec["page-type"])
if _spec["page-type"] == "login":
environ["login"] = content
_op = Operation(_spec["control"])
try:
part = _op(environ, trace, url, response, content, features)
environ.update(dict(zip(ORDER, part)))
(url, response, content) = part
check = factory("check-http-response")()
stat = check(environ, test_output)
check_severity(stat)
except FatalError:
raise
except Exception, err:
environ["exception"] = err
chk = factory("exception")()
chk(environ, test_output)
raise FatalError
# if done:
# break
info = None
qresp = None
resp_type = resp.type
if response:
try:
ctype = response.headers["content-type"]
if ctype == "application/jwt":
resp_type = "jwt"
except (AttributeError, TypeError):
pass
if response.status_code >= 400:
pass
elif not url:
if isinstance(content, Message):
qresp = content
elif response.status_code == 200:
info = content
elif resp.where == "url" or response.status_code == 302:
try:
info = response.headers["location"]
resp_type = "urlencoded"
except KeyError:
try:
_check = getattr(req, "interaction_check", None)
except AttributeError:
_check = None
if _check:
chk = factory("interaction-check")()
chk(environ, test_output)
raise FatalError()
else:
chk = factory("missing-redirect")()
stat = chk(environ, test_output)
check_severity(stat)
else:
check = factory("check_content_type_header")()
stat = check(environ, test_output)
check_severity(stat)
info = content
if info and resp.response:
if isinstance(resp.response, basestring):
response = msgfactory(resp.response)
else:
response = resp.response
chk = factory("response-parse")()
environ["response_type"] = response.__name__
environ["responses"].append((response, info))
try:
qresp = client.parse_response(response, info, resp_type,
client.state,
keystore=_keystore,
client_id=client.client_id,
scope="openid")
if trace and qresp:
trace.info("[%s]: %s" % (qresp.type(),
qresp.to_dict()))
item.append(qresp)
environ["response_message"] = qresp
err = None
except Exception, err:
environ["exception"] = "%s" % err
qresp = None
if err and except_exception:
if isinstance(err, except_exception):
trace.info("Got expected exception: %s [%s]" % (err,
err.__class__.__name__))
else:
raise
else:
stat = chk(environ, test_output)
check_severity(stat)
if qresp:
try:
for test in resp.tests["post"]:
if isinstance(test, tuple):
test, kwargs = test
else:
kwargs = {}
chk = test(**kwargs)
stat = chk(environ, test_output)
check_severity(stat)
except KeyError:
pass
resp(environ, qresp)
if tests is not None:
environ["item"] = item
for test, args in tests:
if isinstance(test, basestring):
chk = factory(test)(**args)
else:
chk = test(**args)
try:
check_severity(chk(environ, test_output))
except Exception, err:
raise FatalError("%s" % err)
except FatalError:
pass
except Exception, err:
environ["exception"] = err
chk = factory("exception")()
chk(environ, test_output)
return test_output, "%s" % trace
def run_sequences(client, sequences, trace, interaction,
verbose=False):
for sequence, endpoints, fid in sequences:
# clear cookie cache
client.grant.clear()
try:
client.http.cookiejar.clear()
except AttributeError:
pass
err = run_sequence(client, sequence, trace, interaction, verbose)
if err:
print "%s - FAIL" % fid
print
if not verbose:
print trace
else:
print "%s - OK" % fid
trace.clear()
return test_output

View File

@@ -1,5 +1,11 @@
import inspect
import sys
import traceback
from saml2.md import EntitiesDescriptor
from saml2.saml import NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT
from saml2.saml import NAME_FORMAT_URI
from saml2.sigver import cert_from_key_info
from saml2.sigver import key_from_key_value
__author__ = 'rolandh'
@@ -67,6 +73,18 @@ class Other(CriticalError):
""" Other error """
msg = "Other error"
class WrapException(CriticalError):
"""
A runtime exception
"""
id = "exception"
msg = "Test tool exception"
def _func(self, environ=None):
self._status = self.status
self._message = traceback.format_exception(*sys.exc_info())
return {}
class CheckHTTPResponse(CriticalError):
"""
Checks that the HTTP response status is within the 200 or 300 range
@@ -76,36 +94,241 @@ class CheckHTTPResponse(CriticalError):
def _func(self, environ):
_response = environ["response"]
_content = environ["content"]
res = {}
if _response.status_code >= 400 :
self._status = self.status
self._message = self.msg
# if CONT_JSON in _response.headers["content-type"]:
# try:
# err = ErrorResponse().deserialize(_content, "json")
# self._message = err.to_json()
# except Exception:
# res["content"] = _content
# else:
# res["content"] = _content
res["url"] = environ["url"]
res["http_status"] = _response.status_code
else:
# might still be an error message
try:
# err = ErrorResponse().deserialize(_content, "json")
# err.verify()
# self._message = err.to_json()
self._status = self.status
except Exception:
pass
res["url"] = environ["url"]
res["content"] = _response.text
return res
class CheckSaml2IntMetaData(Check):
"""
Checks that the Metadata follows the profile
"""
id = "check-saml2int-metadata"
msg = "Metadata error"
def verify_key_info(self, ki):
# key_info
# one or more key_value and/or x509_data.X509Certificate
try:
assert ki.key_value or ki.x509_data
except AssertionError:
self._message = "Missing KeyValue or X509Data.X509Certificate"
self._status = CRITICAL
return False
xkeys = cert_from_key_info(ki)
vkeys = key_from_key_value(ki)
if xkeys and vkeys:
# verify that it's the same keys
pass
return True
def verify_key_descriptor(self, kd):
# key_info
if not self.verify_key_info(kd.key_info):
return False
# use
if kd.use:
try:
assert kd.use in ["encryption", "signing"]
except AssertionError:
self._message = "Unknown use specification: '%s'" % kd.use.text
self._status = CRITICAL
return False
return True
def _func(self, environ):
if isinstance(environ["metadata"], EntitiesDescriptor):
ed = environ["metadata"].entity_descriptor[0]
else:
ed = environ["metadata"]
res = {}
assert len(ed.idpsso_descriptor)
idpsso = ed.idpsso_descriptor[0]
for kd in idpsso.key_descriptor:
if self.verify_key_descriptor(kd) == False:
return res
# contact person
if not idpsso.contact_person:
self._message = "Metadata should contain contact person information"
self._status = WARNING
return res
else:
item = {"support": False, "technical": False}
for contact in idpsso.contact_person:
try:
item[contact.contact_type] = True
except KeyError:
pass
if item["support"] and item["technical"]:
pass
elif not item["support"] and not item["technical"]:
self._message = "Missing technical and support contact information"
self._status = WARNING
elif item["support"]:
self._message = "Missing technical contact information"
self._status = WARNING
elif item["technical"]:
self._message = "Missing support contact information"
self._status = WARNING
if self._message:
return res
# NameID format
if not idpsso.nameid_format:
self._message = "Metadata should specify NameID format support"
self._status = WARNING
return res
else:
# should support Transient
item = {NAMEID_FORMAT_TRANSIENT:False}
for format in idpsso.nameid_format:
try:
item[format] = True
except KeyError:
pass
if not item[NAMEID_FORMAT_TRANSIENT]:
self._message = "IdP should support Transient NameID Format"
self._status = WARNING
return res
return res
class CheckSaml2IntAttributes(Check):
"""
Any <saml2:Attribute> elements exchanged via any SAML 2.0 messages,
assertions, or metadata MUST contain a NameFormat of
urn:oasis:names:tc:SAML:2.0:attrname-format:uri.
"""
id = "check-saml2int-attributes"
msg = "Attribute error"
def _func(self, environ):
response = environ["response"]
try:
opaque_identifier = environ["opaque_identifier"]
except KeyError:
opaque_identifier = False
try:
name_format_not_specified = environ["name_format_not_specified"]
except KeyError:
name_format_not_specified = False
res = {}
# should be a list but isn't
#assert len(response.assertion) == 1
assertion = response.assertion
assert len(assertion.authn_statement) == 1
assert len(assertion.attribute_statement) < 2
if assertion.attribute_statement:
atrstat = assertion.attribute_statement[0]
for attr in atrstat.attribute:
try:
assert attr.name_format == NAME_FORMAT_URI
except AssertionError:
self._message = "Attribute name format error"
self._status = CRITICAL
return res
try:
assert attr.name.startswith("urn:oid")
except AssertionError:
self._message = "Attribute name should be an OID"
self._status = CRITICAL
return res
assert not assertion.subject.encrypted_id
assert not assertion.subject.base_id
if opaque_identifier:
try:
assert assertion.subject.name_id.format == NAMEID_FORMAT_PERSISTENT
except AssertionError:
self._message = "NameID format should be TRANSIENT"
self._status = WARNING
if name_format_not_specified:
try:
assert assertion.subject.name_id.format == NAMEID_FORMAT_TRANSIENT
except AssertionError:
self._message = "NameID format should be TRANSIENT"
self._status = WARNING
return res
class CheckSubjectNameIDFormat(Check):
"""
The <NameIDPolicy> element tailors the name identifier in the subjects of
assertions resulting from an <AuthnRequest>.
When this element is used, if the content is not understood by or acceptable
to the identity provider, then a <Response> message element MUST be
returned with an error <Status>, and MAY contain a second-level
<StatusCode> of urn:oasis:names:tc:SAML:2.0:status:InvalidNameIDPolicy.
If the Format value is omitted or set to urn:oasis:names:tc:SAML:2.0:nameid-
format:unspecified, then the identity provider is free to return any kind
of identifier, subject to any additional constraints due to the content of
this element or the policies of the identity provider or principal.
"""
id = "check-saml2int-attributes"
msg = "Attribute error"
def _func(self, environ):
response = environ["response"]
request = environ["request"]
res ={}
if request.name_id_policy:
format = request.name_id_policy.format
sp_name_qualifier = request.name_id_policy.sp_name_qualifier
subj = response.assertion.subject
try:
assert subj.name_id.format == format
if sp_name_qualifier:
assert subj.name_id.sp_name_qualifier == sp_name_qualifier
except AssertionError:
self._message = "The IdP returns wrong NameID format"
self._status = CRITICAL
return res
class CheckLogoutSupport(Check):
id = "check-logout-support"
msg = "Does not support logout"
def _func(self, environ):
if isinstance(environ["metadata"], EntitiesDescriptor):
ed = environ["metadata"].entity_descriptor[0]
else:
ed = environ["metadata"]
assert len(ed.idpsso_descriptor)
idpsso = ed.idpsso_descriptor[0]
try:
assert idpsso.single_logout_service
except AssertionError:
self._message = self.msg
self._status = CRITICAL
return {}
def factory(id):
for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj):

137
src/idp_test/httpreq.py Normal file
View File

@@ -0,0 +1,137 @@
from Cookie import SimpleCookie
import cookielib
import copy
import logging
import time
import requests
logger = logging.getLogger(__name__)
__author__ = 'rolandh'
# =============================================================================
ATTRS = {"version":None,
"name":"",
"value": None,
"port": None,
"port_specified": False,
"domain": "",
"domain_specified": False,
"domain_initial_dot": False,
"path": "",
"path_specified": False,
"secure": False,
"expires": None,
"discard": True,
"comment": None,
"comment_url": None,
"rest": "",
"rfc2109": True}
PAIRS = {
"port": "port_specified",
"domain": "domain_specified",
"path": "path_specified"
}
def _since_epoch(cdate):
# date format 'Wed, 06-Jun-2012 01:34:34 GMT'
if len(cdate) <= 5:
return 0
cdate = cdate[5:-4]
try:
t = time.strptime(cdate, "%d-%b-%Y %H:%M:%S")
except ValueError:
t = time.strptime(cdate, "%d-%b-%y %H:%M:%S")
return int(time.mktime(t))
class HTTPC(object):
def __init__(self, ca_certs=None):
self.request_args = {"allow_redirects": False,}
#self.cookies = cookielib.CookieJar()
self.cookies = {}
self.cookiejar = cookielib.CookieJar()
if ca_certs:
self.request_args["verify"] = True
else:
self.request_args["verify"] = False
def _cookies(self):
cookie_dict = {}
for _, a in list(self.cookiejar._cookies.items()):
for _, b in list(a.items()):
for cookie in list(b.values()):
# print cookie
cookie_dict[cookie.name] = cookie.value
return cookie_dict
def set_cookie(self, kaka, request):
"""sets a cookie in a Cookie jar based on a set-cookie header line"""
# default rfc2109=False
# max-age, httponly
for cookie_name, morsel in kaka.items():
std_attr = ATTRS.copy()
std_attr["name"] = cookie_name
_tmp = morsel.coded_value
if _tmp.startswith('"') and _tmp.endswith('"'):
std_attr["value"] = _tmp[1:-1]
else:
std_attr["value"] = _tmp
std_attr["version"] = 0
# copy attributes that have values
for attr in morsel.keys():
if attr in ATTRS:
if morsel[attr]:
if attr == "expires":
std_attr[attr]=_since_epoch(morsel[attr])
else:
std_attr[attr]=morsel[attr]
elif attr == "max-age":
if morsel["max-age"]:
std_attr["expires"] = _since_epoch(morsel["max-age"])
for att, set in PAIRS.items():
if std_attr[att]:
std_attr[set] = True
if std_attr["domain"] and std_attr["domain"].startswith("."):
std_attr["domain_initial_dot"] = True
if morsel["max-age"] is 0:
try:
self.cookiejar.clear(domain=std_attr["domain"],
path=std_attr["path"],
name=std_attr["name"])
except ValueError:
pass
else:
new_cookie = cookielib.Cookie(**std_attr)
self.cookiejar.set_cookie(new_cookie)
def request(self, url, method="GET", trace=None, **kwargs):
_kwargs = copy.copy(self.request_args)
if kwargs:
_kwargs.update(kwargs)
if self.cookiejar:
_kwargs["cookies"] = self._cookies()
if trace:
trace.info("SENT COOKIEs: %s" % (_kwargs["cookies"],))
r = requests.request(method, url, **_kwargs)
try:
if trace:
trace.info("RECEIVED COOKIEs: %s" % (r.headers["set-cookie"],))
self.set_cookie(SimpleCookie(r.headers["set-cookie"]), r)
except AttributeError, err:
pass
return r

379
src/idp_test/interaction.py Normal file
View File

@@ -0,0 +1,379 @@
__author__ = 'rohe0002'
import json
from urlparse import urlparse
from bs4 import BeautifulSoup
from mechanize import ParseResponseEx
from mechanize._form import ControlNotFoundError, AmbiguityError
from mechanize._form import ListControl
def pick_interaction(interactions, _base="", content="", req=None):
unic = content
if content:
_bs = BeautifulSoup(content)
else:
_bs = None
for interaction in interactions:
_match = 0
for attr, val in interaction["matches"].items():
if attr == "url":
if val == _base:
_match += 1
elif attr == "title":
if _bs is None:
break
if _bs.title is None:
break
if val in _bs.title.contents:
_match += 1
elif attr == "content":
if unic and val in unic:
_match += 1
elif attr == "class":
if req and val == req:
_match += 1
if _match == len(interaction["matches"]):
return interaction
raise KeyError("No interaction matched")
class FlowException(Exception):
def __init__(self, function="", content="", url=""):
Exception.__init__(self)
self.function = function
self.content = content
self.url = url
def __str__(self):
return json.dumps(self.__dict__)
class RResponse():
"""
A Response class that behaves in the way that mechanize expects it.
Links to a requests.Response
"""
def __init__(self, resp):
self._resp = resp
self.index = 0
self.text = resp.text
if isinstance(self.text, unicode):
if resp.encoding == "UTF-8":
self.text = self.text.encode("utf-8")
else:
self.text = self.text.encode("latin-1")
self._len = len(self.text)
self.url = str(resp.url)
self.statuscode = resp.status_code
def geturl(self):
return self._resp.url
def __getitem__(self, item):
try:
return getattr(self._resp, item)
except AttributeError:
return getattr(self._resp.headers, item)
def __getattribute__(self, item):
try:
return getattr(self._resp, item)
except AttributeError:
return getattr(self._resp.headers, item)
def read(self, size=0):
"""
Read from the content of the response. The class remembers what has
been read so it's possible to read small consecutive parts of the
content.
:param size: The number of bytes to read
:return: Somewhere between zero and 'size' number of bytes depending
on how much it left in the content buffer to read.
"""
if size:
if self._len < size:
return self.text
else:
if self._len == self.index:
part = None
elif self._len - self.index < size:
part = self.text[self.index:]
self.index = self._len
else:
part = self.text[self.index:self.index+size]
self.index += size
return part
else:
return self.text
def pick_form(response, url=None, **kwargs):
"""
Picks which form in a web-page that should be used
:param response: A HTTP request response. A DResponse instance
:param content: The HTTP response content
:param url: The url the request was sent to
:return: The picked form or None of no form matched the criteria.
"""
_txt = response.text
forms = ParseResponseEx(response)
if not forms:
raise FlowException(content=response.text, url=url)
#if len(forms) == 1:
# return forms[0]
#else:
_form = None
# ignore the first form, because I use ParseResponseEx which adds
# one form at the top of the list
forms = forms[1:]
if len(forms) == 1:
_form = forms[0]
else:
if "pick" in kwargs:
_dict = kwargs["pick"]
for form in forms:
if _form:
break
for key, _ava in _dict.items():
if key == "form":
_keys = form.attrs.keys()
for attr, val in _ava.items():
if attr in _keys and val == form.attrs[attr]:
_form = form
elif key == "control":
prop = _ava["id"]
_default = _ava["value"]
try:
orig_val = form[prop]
if isinstance(orig_val, basestring):
if orig_val == _default:
_form = form
elif _default in orig_val:
_form = form
except KeyError:
pass
except ControlNotFoundError:
pass
elif key == "method":
if form.method == _ava:
_form = form
else:
_form = None
if not _form:
break
elif "index" in kwargs:
_form = forms[int(kwargs["index"])]
return _form
def do_click(httpc, form, **kwargs):
"""
Emulates the user clicking submit on a form.
:param httpc: The Client instance
:param form: The form that should be submitted
:return: What do_request() returns
"""
if "click" in kwargs:
request=None
_name = kwargs["click"]
try:
_ = form.find_control(name=_name)
request = form.click(name=_name)
except AmbiguityError:
# more than one control with that name
_val = kwargs["set"][_name]
_nr = 0
while True:
try:
cntrl = form.find_control(name=_name, nr=_nr)
if cntrl.value == _val:
request = form.click(name=_name, nr=_nr)
break
else:
_nr += 1
except ControlNotFoundError:
raise Exception("No submit control with the name='%s' and "
"value='%s' could be found" % (_name,
_val))
else:
request = form.click()
headers = {}
for key, val in request.unredirected_hdrs.items():
headers[key] = val
url = request._Request__original
try:
_trace = kwargs["_trace_"]
except KeyError:
_trace = False
if form.method == "POST":
return httpc.request(url, "POST", data=request.data, headers=headers,
trace=_trace)
else:
return httpc.request(url, "GET", headers=headers, trace=_trace)
def select_form(httpc, orig_response, **kwargs):
"""
Pick a form on a web page, possibly enter some information and submit
the form.
:param httpc: A HTTP client instance
:param orig_response: The original response (as returned by requests)
:return: The response do_click() returns
"""
response = RResponse(orig_response)
try:
_url = response.url
except KeyError:
_url = kwargs["location"]
form = pick_form(response, _url, **kwargs)
#form.backwards_compatible = False
if not form:
raise Exception("Can't pick a form !!")
if "set" in kwargs:
for key, val in kwargs["set"].items():
if key.startswith("_"):
continue
if "click" in kwargs and kwargs["click"] == key:
continue
try:
form[key] = val
except ControlNotFoundError:
pass
except TypeError:
cntrl = form.find_control(key)
if isinstance(cntrl, ListControl):
form[key] = [val]
else:
raise
return do_click(httpc, form, **kwargs)
#noinspection PyUnusedLocal
def chose(httpc, orig_response, path, **kwargs):
"""
Sends a HTTP GET to a url given by the present url and the given
relative path.
:param orig_response: The original response
:param content: The content of the response
:param path: The relative path to add to the base URL
:return: The response do_click() returns
"""
try:
_trace = kwargs["trace"]
except KeyError:
_trace = False
if not path.startswith("http"):
try:
_url = orig_response.url
except KeyError:
_url = kwargs["location"]
part = urlparse(_url)
url = "%s://%s%s" % (part[0], part[1], path)
else:
url = path
return httpc.request(url, "GET", trace=_trace)
#return resp, ""
def post_form(httpc, orig_response, **kwargs):
"""
The same as select_form but with no possibility of change the content
of the form.
:param httpc: A HTTP Client instance
:param orig_response: The original response (as returned by requests)
:param content: The content of the response
:return: The response do_click() returns
"""
response = RResponse(orig_response)
form = pick_form(response, **kwargs)
return do_click(httpc, form, **kwargs)
def NoneFunc():
return None
#noinspection PyUnusedLocal
def parse(httpc, orig_response, **kwargs):
# content is a form from which I get the SAMLResponse
response = RResponse(orig_response)
form = pick_form(response, **kwargs)
#form.backwards_compatible = False
if not form:
raise Exception("Can't pick a form !!")
return {"SAMLResponse": form["SAMLResponse"],
"RelayState":form["RelayState"]}
#noinspection PyUnusedLocal
def interaction(args):
_type = args["type"]
if _type == "form":
return select_form
elif _type == "link":
return chose
elif _type == "response":
return parse
else:
return NoneFunc
# ========================================================================
class Operation(object):
def __init__(self, args=None):
if args:
self.function = interaction(args)
self.args = args or {}
self.request = None
def update(self, dic):
self.args.update(dic)
#noinspection PyUnusedLocal
def post_op(self, result, environ, args):
pass
def __call__(self, httpc, environ, trace, location, response, content,
features):
try:
_args = self.args.copy()
except (KeyError, AttributeError):
_args = {}
_args["_trace_"] = trace
_args["location"] = location
_args["features"] = features
if trace:
trace.reply("FUNCTION: %s" % self.function.__name__)
trace.reply("ARGS: %s" % _args)
result = self.function(httpc, response, **_args)
self.post_op(result, environ, _args)
return result

View File

@@ -1,95 +0,0 @@
from check import CheckHTTPResponse
__author__ = 'rolandh'
class Request():
request = ""
method = ""
lax = False
_request_args= {}
kw_args = {}
tests = {"post": [CheckHTTPResponse], "pre":[]}
def __init__(self, cconf=None):
self.cconf = cconf
self.request_args = self._request_args.copy()
#noinspection PyUnusedLocal
def __call__(self, environ, trace, location, response, content, features):
_client = environ["client"]
try:
kwargs = self.kw_args.copy()
except KeyError:
kwargs = {}
func = getattr(_client, "do_%s" % self.request)
ht_add = None
if "authn_method" in kwargs:
h_arg = _client.init_authentication_method(cis, **kwargs)
else:
h_arg = None
url, body, ht_args, cis = _client.uri_and_body(request, cis,
method=self.method,
request_args=_req)
environ["cis"].append(cis)
if h_arg:
ht_args.update(h_arg)
if ht_add:
ht_args.update({"headers": ht_add})
if trace:
try:
oro = unpack(cis["request"])[1]
trace.request("OpenID Request Object: %s" % oro)
except KeyError:
pass
trace.request("URL: %s" % url)
trace.request("BODY: %s" % body)
try:
trace.request("HEADERS: %s" % ht_args["headers"])
except KeyError:
pass
response = _client.http_request(url, method=self.method, data=body,
**ht_args)
if trace:
trace.reply("RESPONSE: %s" % response)
trace.reply("CONTENT: %s" % response.text)
if response.status_code in [301, 302]:
trace.reply("LOCATION: %s" % response.headers["location"])
trace.reply("COOKIES: %s" % response.cookies)
# try:
# trace.reply("HeaderCookies: %s" % response.headers["set-cookie"])
# except KeyError:
# pass
return url, response, response.text
def update(self, dic):
_tmp = {"request": self.request_args.copy(), "kw": self.kw_args}
for key, val in self.rec_update(_tmp, dic).items():
setattr(self, "%s_args" % key, val)
def rec_update(self, dic0, dic1):
res = {}
for key, val in dic0.items():
if key not in dic1:
res[key] = val
else:
if isinstance(val, dict):
res[key] = self.rec_update(val, dic1[key])
else:
res[key] = dic1[key]
for key, val in dic1.items():
if key in dic0:
continue
else:
res[key] = val
return res

71
src/idp_test/saml2int.py Normal file
View File

@@ -0,0 +1,71 @@
from saml2 import BINDING_HTTP_REDIRECT
from saml2 import BINDING_HTTP_POST
from saml2.saml import NAMEID_FORMAT_PERSISTENT
from idp_test.check import CheckSaml2IntMetaData
from idp_test.check import CheckSaml2IntAttributes
from idp_test.check import CheckSubjectNameIDFormat
from idp_test.check import CheckLogoutSupport
__author__ = 'rolandh'
class Request(object):
_args = {}
def __init__(self):
self.args = self._args.copy()
def setup(self, environ):
pass
class AuthnRequest(Request):
request = "authn_request"
_args = {"binding": BINDING_HTTP_REDIRECT,
"nameid_format": NAMEID_FORMAT_PERSISTENT,
"allow_create": True}
tests = {"pre": [CheckSaml2IntMetaData],
"post": [CheckSaml2IntAttributes,
# CheckSubjectNameIDFormat
]}
class AuthnRequestPost(AuthnRequest):
def __init__(self):
AuthnRequest.__init__(self)
self.args["binding"] = BINDING_HTTP_POST
class LogOutRequest(Request):
request = "logout_request"
tests = {"pre": [CheckLogoutSupport]}
_args = {"binding": BINDING_HTTP_REDIRECT,
# "sign": True
}
def setup(self, environ):
resp = environ["response"]
subj = resp.assertion.subject
self.args["subject_id"] = subj.name_id.text
#self.args["name_id"] = subj.name_id
self.args["issuer_entity_id"] = resp.assertion.issuer.text
OPERATIONS = {
'basic-authn': {
"name": 'Absolute basic SAML2 AuthnRequest',
"descr": ('AuthnRequest using HTTP-redirect'),
"sequence": [AuthnRequest],
#"endpoints": ["authorization_endpoint"],
#"block": ["key_export"]
},
'basic-authn-post': {
"name": 'Basic SAML2 AuthnRequest using HTTP POST',
"descr": ('AuthnRequest using HTTP-POST'),
"sequence": [AuthnRequestPost],
#"endpoints": ["authorization_endpoint"],
#"block": ["key_export"]
},
'log-in-out': {
"name": 'Absolute basic SAML2 AuthnRequest',
"descr": ('AuthnRequest using HTTP-redirect'),
"sequence": [AuthnRequest, LogOutRequest],
}
}

199
src/idp_test/test.py Normal file
View File

@@ -0,0 +1,199 @@
import cookielib
import inspect
from saml2 import BINDING_HTTP_REDIRECT
from saml2 import BINDING_HTTP_POST
from saml2 import BINDING_SOAP
from saml2.binding import http_redirect_message
from saml2.binding import http_post_message
from saml2.binding import send_using_soap
from saml2.client import Saml2Client
from saml2.config import SPConfig
from saml2.s_utils import rndstr
from saml2.saml import NAMEID_FORMAT_PERSISTENT
from saml2.metadata import REQ2SRV
import time
from idp_test.interaction import Operation
from idp_test.interaction import pick_interaction
from idp_test.check import factory
from idp_test import SAML2
__author__ = 'rolandh'
class FatalError(Exception):
pass
ORDER = ["url", "response", "content"]
def check_severity(stat):
if stat["status"] >= 4:
raise FatalError
def intermit(client, response, httpc, environ, trace, cjar, interaction,
test_output, features=None):
if response.status_code >= 400:
done = True
else:
done = False
while not done:
while response.status_code in [302, 301, 303]:
url = response.headers["location"]
trace.reply("REDIRECT TO: %s" % url)
# If back to me
for_me = False
acs = client.config.getattr("endpoints",
"sp")["assertion_consumer_service"]
for redirect_uri in acs:
if url.startswith(redirect_uri):
# Back at the RP
environ["client"].cookiejar = cjar["rp"]
for_me=True
if for_me:
done = True
break
else:
try:
response = httpc.request(url, "GET", trace=trace)
except Exception, err:
raise FatalError("%s" % err)
content = response.text
environ.update({"url": url, "response": response,
"content":content})
check = factory("check-http-response")()
stat = check(environ, test_output)
check_severity(stat)
if done:
break
_base = url.split("?")[0]
try:
_spec = pick_interaction(interaction, _base, content)
except KeyError:
chk = factory("interaction-needed")()
chk(environ, test_output)
raise FatalError()
if len(_spec) > 2:
trace.info(">> %s <<" % _spec["page-type"])
if _spec["page-type"] == "login":
environ["login"] = content
_op = Operation(_spec["control"])
try:
response = _op(httpc, environ, trace, url, response, content,
features)
if isinstance(response, dict):
return response
content = response.text
environ.update({"url": url, "response": response,
"content":content})
check = factory("check-http-response")()
stat = check(environ, test_output)
check_severity(stat)
except FatalError:
raise
except Exception, err:
environ["exception"] = err
chk = factory("exception")()
chk(environ, test_output)
return response
def do_query(config, oper, httpc, trace, interaction):
environ = {}
test_output = []
client = Saml2Client(config)
query = oper.request
args = oper.args
cjar = {"browser": cookielib.CookieJar(),
"rp": cookielib.CookieJar(),
"service": cookielib.CookieJar()}
httpc.cookiejar = cjar["browser"]
locations = getattr(client.metadata, REQ2SRV[query])(args["entity_id"],
args["binding"])
relay_state = rndstr()
_response_func = getattr(client, "%s_response" % query)
response_args = {}
qargs = args.copy()
qfunc = getattr(client, "create_%s" % query)
# remove args the create function can't handle
fargs = inspect.getargspec(qfunc).args
for arg in qargs.keys():
if arg not in fargs:
del qargs[arg]
resp = None
for loc in locations:
qargs["destination"] = loc
req = qfunc(**qargs)
_req_str = "%s" % req
# depending on binding send the query
if args["binding"] is BINDING_HTTP_REDIRECT:
(head, _body) = http_redirect_message(_req_str, loc, relay_state)
res = httpc.request(head[0][1], "GET")
response_args["outstanding"] = {req.id: "/"}
# head should contain a redirect
# deal with redirect, should in the end give me a response
response = intermit(client, res, httpc, environ, trace, cjar,
interaction, test_output, features)
if isinstance(response, dict):
assert relay_state == response["RelayState"]
elif args["binding"] is BINDING_HTTP_POST:
(head, response) = http_post_message(_req_str, loc, relay_state)
elif args["binding"] is BINDING_SOAP:
response = send_using_soap(_req_str, loc, client.config.key_file,
client.config.cert_file,
ca_certs=client.config.ca_certs)
else:
response = None
if response:
try:
_ = _response_func(response, **response_args)
break
except Exception, err:
environ["exception"] = err
chk = factory("exception")()
chk(environ, test_output)
return test_output, "%s" % trace
# ========================================================================
class Request(object):
_args = {}
def __init__(self):
self.args = self._args.copy()
class AuthnRequest(Request):
request = "authn_request"
_args = {"binding": BINDING_HTTP_REDIRECT,
"nameid_format": NAMEID_FORMAT_PERSISTENT}
# ========================================================================
if __name__ == "__main__":
s = SAML2()
s.run()