Updated to comply with changes in pysaml2.
This commit is contained in:
10
script/saml2c.py
Executable file
10
script/saml2c.py
Executable file
@@ -0,0 +1,10 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
__author__ = 'rohe0002'
|
||||
|
||||
from idp_test import saml2int
|
||||
from idp_test import SAML2client
|
||||
|
||||
cli = SAML2client(saml2int)
|
||||
|
||||
cli.run()
|
||||
3
setup.py
3
setup.py
@@ -20,7 +20,7 @@ from setuptools import setup
|
||||
__author__ = 'rohe0002'
|
||||
|
||||
setup(
|
||||
name="oic",
|
||||
name="saml2test",
|
||||
version="0.3.0",
|
||||
description="SAML2 test tool",
|
||||
author = "Roland Hedberg",
|
||||
@@ -37,4 +37,5 @@ setup(
|
||||
"beautifulsoup4"],
|
||||
|
||||
zip_safe=False,
|
||||
scripts=["script/saml2c.py"]
|
||||
)
|
||||
@@ -1,11 +1,39 @@
|
||||
from importlib import import_module
|
||||
import json
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
|
||||
import logging
|
||||
|
||||
from saml2.config import SPConfig
|
||||
|
||||
from idp_test.base import FatalError
|
||||
from idp_test.base import do_sequence
|
||||
from idp_test.httpreq import HTTPC
|
||||
#from saml2.config import Config
|
||||
from saml2.mdstore import MetadataStore, MetaData
|
||||
|
||||
# Schemas supported
|
||||
from saml2 import md
|
||||
from saml2 import saml
|
||||
from saml2.extension import mdui
|
||||
from saml2.extension import idpdisc
|
||||
from saml2.extension import dri
|
||||
from saml2.extension import mdattr
|
||||
from saml2.extension import ui
|
||||
from saml2.metadata import entity_descriptor
|
||||
import xmldsig
|
||||
import xmlenc
|
||||
|
||||
SCHEMA = [ dri, idpdisc, md, mdattr, mdui, saml, ui, xmldsig, xmlenc]
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
import traceback
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
def exception_trace(tag, exc, log=None):
|
||||
message = traceback.format_exception(*sys.exc_info())
|
||||
if log:
|
||||
@@ -15,15 +43,58 @@ def exception_trace(tag, exc, log=None):
|
||||
print >> sys.stderr, "[%s] ExcList: %s" % (tag, "".join(message),)
|
||||
print >> sys.stderr, "[%s] Exception: %s" % (tag, exc)
|
||||
|
||||
class SAML2(object):
|
||||
client_args = ["client_id", "redirect_uris", "password"]
|
||||
class Trace(object):
|
||||
def __init__(self):
|
||||
self.trace = []
|
||||
self.start = time.time()
|
||||
|
||||
def __init__(self, operations_mod, client_class, msgfactory):
|
||||
self.operations_mod = operations_mod
|
||||
self.client_class = client_class
|
||||
self.client = None
|
||||
#self.trace = Trace()
|
||||
self.msgfactory = msgfactory
|
||||
def request(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f --> %s" % (delta, msg))
|
||||
|
||||
def reply(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f <-- %s" % (delta, msg))
|
||||
|
||||
def info(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f %s" % (delta, msg))
|
||||
|
||||
def error(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f [ERROR] %s" % (delta, msg))
|
||||
|
||||
def warning(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f [WARNING] %s" % (delta, msg))
|
||||
|
||||
def __str__(self):
|
||||
try:
|
||||
return "\n".join([t.encode("utf-8") for t in self.trace])
|
||||
except UnicodeDecodeError:
|
||||
arr = []
|
||||
for t in self.trace:
|
||||
try:
|
||||
arr.append(t.encode("utf-8"))
|
||||
except UnicodeDecodeError:
|
||||
arr.append(t)
|
||||
return "\n".join(arr)
|
||||
|
||||
def clear(self):
|
||||
self.trace = []
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.trace[item]
|
||||
|
||||
def next(self):
|
||||
for line in self.trace:
|
||||
yield line
|
||||
|
||||
class SAML2client(object):
|
||||
|
||||
def __init__(self, operations):
|
||||
self.trace = Trace()
|
||||
self.operations = operations
|
||||
|
||||
self._parser = argparse.ArgumentParser()
|
||||
self._parser.add_argument('-d', dest='debug', action='store_true',
|
||||
@@ -31,35 +102,21 @@ class SAML2(object):
|
||||
self._parser.add_argument('-v', dest='verbose', action='store_true',
|
||||
help="Print runtime information")
|
||||
self._parser.add_argument('-C', dest="ca_certs",
|
||||
help="CA certs to use to verify HTTPS server certificates, if HTTPS is used and no server CA certs are defined then no cert verification is done")
|
||||
help="CA certs to use to verify HTTPS server certificates, if HTTPS is used and no server CA certs are defined then no cert verification will be done")
|
||||
self._parser.add_argument('-J', dest="json_config_file",
|
||||
help="Script configuration")
|
||||
self._parser.add_argument('-S', dest="sp_id", help="SP id")
|
||||
self._parser.add_argument("-s", dest="list_sp_id", action="store_true",
|
||||
help="List all the SP variants as a JSON object")
|
||||
self._parser.add_argument('-m', dest="metadata", action='store_true',
|
||||
help="Return the SP metadata")
|
||||
self._parser.add_argument("-l", dest="list", action="store_true",
|
||||
help="List all the test flows as a JSON object")
|
||||
self._parser.add_argument("-H", dest="host", default="example.com",
|
||||
help="Which host the script is running on, used to construct the key export URL")
|
||||
self._parser.add_argument("flow", nargs="?", help="Which test flow to run")
|
||||
self._parser.add_argument("oper", nargs="?", help="Which test to run")
|
||||
|
||||
self.args = None
|
||||
self.pinfo = None
|
||||
self.sequences = []
|
||||
self.function_args = {}
|
||||
self.signing_key = None
|
||||
self.encryption_key = None
|
||||
self.test_log = []
|
||||
self.environ = {}
|
||||
self._pop = None
|
||||
|
||||
def parse_args(self):
|
||||
self.json_config= self.json_config_file()
|
||||
|
||||
try:
|
||||
self.features = self.json_config["features"]
|
||||
except KeyError:
|
||||
self.features = {}
|
||||
|
||||
self.pinfo = self.provider_info()
|
||||
self.client_conf(self.client_args)
|
||||
self.interactions = None
|
||||
self.entity_id = None
|
||||
self.sp_config = None
|
||||
|
||||
def json_config_file(self):
|
||||
if self.args.json_config_file == "-":
|
||||
@@ -67,6 +124,34 @@ class SAML2(object):
|
||||
else:
|
||||
return json.loads(open(self.args.json_config_file).read())
|
||||
|
||||
def sp_configure(self, metadata_construction=False):
|
||||
sys.path.insert(0, ".")
|
||||
mod = import_module("config_file")
|
||||
if self.args.sp_id is None:
|
||||
if len(mod.CONFIG) == 1:
|
||||
self.args.sp_id = mod.CONFIG.keys()[0]
|
||||
else:
|
||||
raise Exception("SP id undefined")
|
||||
|
||||
self.sp_config = SPConfig().load(mod.CONFIG[self.args.sp_id],
|
||||
metadata_construction)
|
||||
|
||||
def setup(self):
|
||||
self.json_config= self.json_config_file()
|
||||
|
||||
_jc = self.json_config
|
||||
|
||||
self.interactions = _jc["interaction"]
|
||||
self.entity_id = _jc["entity_id"]
|
||||
|
||||
self.sp_configure()
|
||||
|
||||
metadata = MetadataStore(SCHEMA, self.sp_config.attribute_converters,
|
||||
self.sp_config.xmlsec_binary)
|
||||
metadata[0] = MetaData(SCHEMA, self.sp_config.attribute_converters,
|
||||
_jc["metadata"])
|
||||
self.sp_config.metadata = metadata
|
||||
|
||||
def test_summation(self, id):
|
||||
status = 0
|
||||
for item in self.test_log:
|
||||
@@ -91,75 +176,45 @@ class SAML2(object):
|
||||
def run(self):
|
||||
self.args = self._parser.parse_args()
|
||||
|
||||
if self.args.list:
|
||||
return self.operations()
|
||||
if self.args.metadata:
|
||||
return self.make_meta()
|
||||
elif self.args.list_sp_id:
|
||||
return self.list_conf_id()
|
||||
elif self.args.list:
|
||||
return self.list_operations()
|
||||
else:
|
||||
if not self.args.flow:
|
||||
raise Exception("Missing flow specification")
|
||||
self.args.flow = self.args.flow.strip("'")
|
||||
self.args.flow = self.args.flow.strip('"')
|
||||
if not self.args.oper:
|
||||
raise Exception("Missing test case specification")
|
||||
self.args.oper = self.args.oper.strip("'")
|
||||
self.args.oper = self.args.oper.strip('"')
|
||||
|
||||
flow_spec = self.operations_mod.FLOWS[self.args.flow]
|
||||
self.setup()
|
||||
|
||||
try:
|
||||
try:
|
||||
block = flow_spec["block"]
|
||||
oper = self.operations.OPERATIONS[self.args.oper]
|
||||
except KeyError:
|
||||
block = {}
|
||||
|
||||
self.parse_args()
|
||||
_seq = self.make_sequence()
|
||||
interact = self.get_interactions()
|
||||
|
||||
try:
|
||||
self.do_features(interact, _seq, block)
|
||||
except Exception,exc:
|
||||
exception_trace("do_features", exc)
|
||||
_output = {"status": 4,
|
||||
"tests": [{"status": 4,
|
||||
"message":"Couldn't run testflow: %s" % exc,
|
||||
"id": "verify_features",
|
||||
"name": "Make sure you don't do things you shouldn't"}]}
|
||||
#print >> sys.stdout, json.dumps(_output)
|
||||
print >> sys.stderr, "Undefined testcase"
|
||||
return
|
||||
|
||||
tests = self.get_test()
|
||||
self.client.state = "STATE0"
|
||||
testres, trace = do_sequence(self.sp_config, oper, HTTPC(),
|
||||
self.trace, self.interactions,
|
||||
entity_id=self.json_config["entity_id"])
|
||||
self.test_log = testres
|
||||
sum = self.test_summation(self.args.oper)
|
||||
print >>sys.stdout, json.dumps(sum)
|
||||
if sum["status"] > 1 or self.args.debug:
|
||||
print >> sys.stderr, trace
|
||||
except FatalError:
|
||||
pass
|
||||
except Exception, err:
|
||||
print >> sys.stderr, self.trace
|
||||
print err
|
||||
exception_trace("RUN", err)
|
||||
|
||||
self.environ.update({"provider_info": self.pinfo,
|
||||
"client": self.client})
|
||||
|
||||
try:
|
||||
except_exception = flow_spec["except_exception"]
|
||||
except KeyError:
|
||||
except_exception = False
|
||||
|
||||
try:
|
||||
if self.args.verbose:
|
||||
print >> sys.stderr, "Set up done, running sequence"
|
||||
testres, trace = run_sequence(self.client, _seq, self.trace,
|
||||
interact, self.msgfactory,
|
||||
self.environ, tests,
|
||||
self.json_config["features"],
|
||||
self.args.verbose, self.cconf,
|
||||
except_exception)
|
||||
self.test_log.extend(testres)
|
||||
sum = self.test_summation(self.args.flow)
|
||||
print >>sys.stdout, json.dumps(sum)
|
||||
if sum["status"] > 1 or self.args.debug:
|
||||
print >>sys.stderr, trace
|
||||
except Exception, err:
|
||||
#print >> sys.stderr, self.trace
|
||||
print err
|
||||
exception_trace("RUN", err)
|
||||
|
||||
#if self._pop is not None:
|
||||
# self._pop.terminate()
|
||||
if "keyprovider" in self.environ and self.environ["keyprovider"]:
|
||||
# os.kill(self.environ["keyprovider"].pid, signal.SIGTERM)
|
||||
self.environ["keyprovider"].terminate()
|
||||
|
||||
def operations(self):
|
||||
def list_operations(self):
|
||||
lista = []
|
||||
for key,val in self.operations_mod.FLOWS.items():
|
||||
for key,val in self.operations.OPERATIONS.items():
|
||||
item = {"id": key,
|
||||
"name": val["name"],}
|
||||
try:
|
||||
@@ -178,112 +233,17 @@ class SAML2(object):
|
||||
pass
|
||||
|
||||
lista.append(item)
|
||||
|
||||
print json.dumps(lista)
|
||||
|
||||
def provider_info(self):
|
||||
# Should provide a Metadata class
|
||||
res = {}
|
||||
_jc = self.json_config["provider"]
|
||||
def _get_operation(self, operation):
|
||||
return self.operations.OPERATIONS[operation]
|
||||
|
||||
# Backward compatible
|
||||
if "endpoints" in _jc:
|
||||
try:
|
||||
for endp, url in _jc["endpoints"].items():
|
||||
res[endp] = url
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
for key in ProviderConfigurationResponse.c_param.keys():
|
||||
try:
|
||||
res[key] = _jc[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return res
|
||||
|
||||
def do_features(self, *args):
|
||||
pass
|
||||
|
||||
def export(self):
|
||||
pass
|
||||
|
||||
def client_conf(self, cprop):
|
||||
if self.args.ca_certs:
|
||||
self.client = self.client_class(ca_certs=self.args.ca_certs)
|
||||
else:
|
||||
try:
|
||||
self.client = self.client_class(
|
||||
ca_certs=self.json_config["ca_certs"])
|
||||
except (KeyError, TypeError):
|
||||
self.client = self.client_class()
|
||||
|
||||
#self.client.http_request = self.client.http.crequest
|
||||
|
||||
# set the endpoints in the Client from the provider information
|
||||
# If they are statically configured, if dynamic it happens elsewhere
|
||||
for key, val in self.pinfo.items():
|
||||
if key.endswith("_endpoint"):
|
||||
setattr(self.client, key, val)
|
||||
|
||||
# Client configuration
|
||||
self.cconf = self.json_config["client"]
|
||||
# replace pattern with real value
|
||||
_h = self.args.host
|
||||
self.cconf["redirect_uris"] = [p % _h for p in self.cconf["redirect_uris"]]
|
||||
|
||||
try:
|
||||
self.client.client_prefs = self.cconf["preferences"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# set necessary information in the Client
|
||||
for prop in cprop:
|
||||
try:
|
||||
setattr(self.client, prop, self.cconf[prop])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def make_sequence(self):
|
||||
# Whatever is specified on the command line takes precedences
|
||||
if self.args.flow:
|
||||
sequence = flow2sequence(self.operations_mod, self.args.flow)
|
||||
elif self.json_config and "flow" in self.json_config:
|
||||
sequence = flow2sequence(self.operations_mod,
|
||||
self.json_config["flow"])
|
||||
else:
|
||||
sequence = None
|
||||
|
||||
return sequence
|
||||
|
||||
def get_interactions(self):
|
||||
interactions = []
|
||||
|
||||
if self.json_config:
|
||||
try:
|
||||
interactions = self.json_config["interaction"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if self.args.interactions:
|
||||
_int = self.args.interactions.replace("\'", '"')
|
||||
if interactions:
|
||||
interactions.update(json.loads(_int))
|
||||
else:
|
||||
interactions = json.loads(_int)
|
||||
|
||||
return interactions
|
||||
|
||||
def get_test(self):
|
||||
if self.args.flow:
|
||||
flow = self.operations_mod.FLOWS[self.args.flow]
|
||||
elif self.json_config and "flow" in self.json_config:
|
||||
flow = self.operations_mod.FLOWS[self.json_config["flow"]]
|
||||
else:
|
||||
flow = None
|
||||
|
||||
try:
|
||||
return flow["tests"]
|
||||
except KeyError:
|
||||
return []
|
||||
def make_meta(self):
|
||||
self.sp_configure(True)
|
||||
print entity_descriptor(self.sp_config)
|
||||
|
||||
def list_conf_id(self):
|
||||
sys.path.insert(0, ".")
|
||||
mod = import_module("config_file")
|
||||
_res = dict([(key, cnf["description"]) for key, cnf in mod.CONFIG.items()])
|
||||
print json.dumps(_res)
|
||||
|
||||
@@ -1,392 +1,281 @@
|
||||
#!/usr/bin/env python
|
||||
from check import ExpectedError
|
||||
from check import factory
|
||||
import base64
|
||||
import inspect
|
||||
from saml2 import BINDING_HTTP_REDIRECT
|
||||
from saml2 import BINDING_HTTP_POST
|
||||
from saml2 import BINDING_SOAP
|
||||
from saml2.client import Saml2Client
|
||||
|
||||
#from idp_test.check import ExpectedError
|
||||
from saml2.mdstore import REQ2SRV
|
||||
from saml2.pack import http_redirect_message
|
||||
from saml2.s_utils import rndstr
|
||||
from idp_test.check import factory
|
||||
from idp_test.check import STATUSCODE
|
||||
from idp_test.interaction import Operation
|
||||
from idp_test.interaction import pick_interaction
|
||||
|
||||
__author__ = 'rohe0002'
|
||||
|
||||
import time
|
||||
import cookielib
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
class FatalError(Exception):
|
||||
pass
|
||||
|
||||
class Trace(object):
|
||||
def __init__(self):
|
||||
self.trace = []
|
||||
self.start = time.time()
|
||||
def form_post(request, relay_state):
|
||||
return "SAMLRequest=%s&RelayState=%s" % (base64.b64encode("%s" % request),
|
||||
relay_state)
|
||||
|
||||
def request(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f --> %s" % (delta, msg))
|
||||
|
||||
def reply(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f <-- %s" % (delta, msg))
|
||||
|
||||
def info(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f %s" % (delta, msg))
|
||||
|
||||
def error(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f [ERROR] %s" % (delta, msg))
|
||||
|
||||
def warning(self, msg):
|
||||
delta = time.time() - self.start
|
||||
self.trace.append("%f [WARNING] %s" % (delta, msg))
|
||||
|
||||
def __str__(self):
|
||||
return "\n". join([t.encode("utf-8") for t in self.trace])
|
||||
|
||||
def clear(self):
|
||||
self.trace = []
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.trace[item]
|
||||
|
||||
def next(self):
|
||||
for line in self.trace:
|
||||
yield line
|
||||
|
||||
def flow2sequence(operations, item):
|
||||
flow = operations.FLOWS[item]
|
||||
return [operations.PHASES[phase] for phase in flow["sequence"]]
|
||||
|
||||
def endpoint(client, base):
|
||||
for _endp in client._endpoints:
|
||||
if getattr(client, _endp) == base:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_severity(stat):
|
||||
def check_severity(stat, trace):
|
||||
if stat["status"] >= 4:
|
||||
trace.error("WHERE: %s" % stat["id"])
|
||||
trace.error("STATUS:%s" % STATUSCODE[stat["status"]])
|
||||
try:
|
||||
trace.error("HTTP STATUS: %s" % stat["http_status"])
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
trace.error("INFO: %s" % stat["message"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
raise FatalError
|
||||
|
||||
|
||||
def pick_interaction(interactions, _base="", content="", req=None):
|
||||
unic = content
|
||||
if content:
|
||||
_bs = BeautifulSoup(content)
|
||||
def intermit(client, response, httpc, environ, trace, cjar, interaction,
|
||||
test_output, features=None):
|
||||
if response.status_code >= 400:
|
||||
done = True
|
||||
else:
|
||||
_bs = None
|
||||
done = False
|
||||
|
||||
for interaction in interactions:
|
||||
_match = 0
|
||||
for attr, val in interaction["matches"].items():
|
||||
if attr == "url":
|
||||
if val == _base:
|
||||
_match += 1
|
||||
elif attr == "title":
|
||||
if _bs is None:
|
||||
break
|
||||
if _bs.title is None:
|
||||
break
|
||||
if val in _bs.title.contents:
|
||||
_match += 1
|
||||
elif attr == "content":
|
||||
if unic and val in unic:
|
||||
_match += 1
|
||||
elif attr == "class":
|
||||
if req and val == req:
|
||||
_match += 1
|
||||
url = response.url
|
||||
content = response.text
|
||||
|
||||
if _match == len(interaction["matches"]):
|
||||
return interaction
|
||||
while not done:
|
||||
while response.status_code in [302, 301, 303]:
|
||||
url = response.headers["location"]
|
||||
|
||||
raise KeyError("No interaction matched")
|
||||
trace.reply("REDIRECT TO: %s" % url)
|
||||
# If back to me
|
||||
for_me = False
|
||||
acs = client.config.getattr("endpoints",
|
||||
"sp")["assertion_consumer_service"]
|
||||
for redirect_uri in acs:
|
||||
if url.startswith(redirect_uri):
|
||||
# Back at the RP
|
||||
environ["client"].cookiejar = cjar["rp"]
|
||||
for_me=True
|
||||
|
||||
ORDER = ["url", "response", "content"]
|
||||
if for_me:
|
||||
done = True
|
||||
break
|
||||
else:
|
||||
try:
|
||||
response = httpc.request(url, "GET", trace=trace)
|
||||
except Exception, err:
|
||||
raise FatalError("%s" % err)
|
||||
|
||||
def run_sequence(client, sequence, trace, interaction, msgfactory,
|
||||
environ=None, tests=None, features=None, verbose=False,
|
||||
cconf=None, except_exception=None):
|
||||
item = []
|
||||
response = None
|
||||
content = None
|
||||
url = ""
|
||||
content = response.text
|
||||
trace.reply("CONTENT: %s" % content)
|
||||
environ.update({"url": url, "response": response,
|
||||
"content":content})
|
||||
|
||||
check = factory("check-http-response")()
|
||||
stat = check(environ, test_output)
|
||||
check_severity(stat, trace)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
_base = url.split("?")[0]
|
||||
|
||||
try:
|
||||
_spec = pick_interaction(interaction, _base, content)
|
||||
except KeyError:
|
||||
chk = factory("interaction-needed")()
|
||||
chk(environ, test_output)
|
||||
raise FatalError()
|
||||
|
||||
if len(_spec) > 2:
|
||||
trace.info(">> %s <<" % _spec["page-type"])
|
||||
if _spec["page-type"] == "login":
|
||||
environ["login"] = content
|
||||
|
||||
_op = Operation(_spec["control"])
|
||||
|
||||
try:
|
||||
response = _op(httpc, environ, trace, url, response, content,
|
||||
features)
|
||||
if isinstance(response, dict):
|
||||
return response
|
||||
content = response.text
|
||||
environ.update({"url": url, "response": response,
|
||||
"content":content})
|
||||
|
||||
check = factory("check-http-response")()
|
||||
stat = check(environ, test_output)
|
||||
check_severity(stat, trace)
|
||||
except FatalError:
|
||||
raise
|
||||
except Exception, err:
|
||||
environ["exception"] = err
|
||||
chk = factory("exception")()
|
||||
chk(environ, test_output)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def do_sequence(config, oper, httpc, trace, interaction, entity_id,
|
||||
features=None):
|
||||
"""
|
||||
|
||||
:param config: SP configuration
|
||||
:param oper: A dictionary describing the operations to perform
|
||||
:param httpc: A HTTP Client instance
|
||||
:param trace: A Trace instance that keep all the trace information
|
||||
:param interaction: A list of interaction definitions
|
||||
:param entity_id: The entity_id of the IdP
|
||||
:param features: ?
|
||||
:returns: A 2-tuple (testoutput, tracelog)
|
||||
"""
|
||||
|
||||
client = Saml2Client(config)
|
||||
test_output = []
|
||||
_keystore = client.keystore
|
||||
features = features or {}
|
||||
if client.metadata.entities_descr["-"]:
|
||||
environ = {"metadata": client.metadata.entities_descr["-"]}
|
||||
else:
|
||||
environ = {"metadata": client.metadata.entity_descr["-"]}
|
||||
|
||||
cjar = {"browser": cookielib.CookieJar(),
|
||||
"rp": cookielib.CookieJar(),
|
||||
"service": cookielib.CookieJar()}
|
||||
|
||||
environ["sequence"] = sequence
|
||||
environ["cis"] = []
|
||||
environ["trace"] = trace
|
||||
environ["responses"] = []
|
||||
environ["FatalError"] = False
|
||||
for op in oper["sequence"]:
|
||||
output = do_query(client, op(), httpc, trace, interaction, entity_id,
|
||||
environ, cjar, features)
|
||||
test_output.extend(output)
|
||||
if environ["FatalError"]:
|
||||
break
|
||||
return test_output, "%s" % trace
|
||||
|
||||
|
||||
def do_query(client, oper, httpc, trace, interaction, entity_id, environ, cjar,
|
||||
features=None):
|
||||
"""
|
||||
|
||||
:param client: A SAML2 client
|
||||
:param oper: A Request class instance
|
||||
:param httpc: A HTTP Client instance
|
||||
:param trace: A Trace instance that keep all the trace information
|
||||
:param interaction: A list of interaction definitions
|
||||
:param entity_id: The entity_id of the IdP
|
||||
:param environ: Local environment
|
||||
:param features: ?
|
||||
:returns: A 2-tuple (testoutput, tracelog)
|
||||
"""
|
||||
|
||||
oper.setup(environ)
|
||||
query = oper.request
|
||||
args = oper.args
|
||||
args["entity_id"] = entity_id
|
||||
test_output = []
|
||||
|
||||
try:
|
||||
for creq in sequence:
|
||||
req = creq()
|
||||
cfunc = getattr(client, "create_%s" % req.request)
|
||||
if trace:
|
||||
trace.info(70*"=")
|
||||
|
||||
for test in oper.tests["pre"]:
|
||||
chk = test()
|
||||
stat = chk(environ, test_output)
|
||||
try:
|
||||
_pretests = req.tests["pre"]
|
||||
for test in _pretests:
|
||||
chk = test()
|
||||
stat = chk(environ, test_output)
|
||||
check_severity(stat)
|
||||
except KeyError:
|
||||
pass
|
||||
check_severity(stat, trace)
|
||||
except FatalError:
|
||||
environ["FatalError"] = True
|
||||
raise
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
httpc.cookiejar = cjar["browser"]
|
||||
|
||||
locations = getattr(client.metadata, REQ2SRV[query])(args["entity_id"],
|
||||
args["binding"])
|
||||
|
||||
relay_state = rndstr()
|
||||
_response_func = getattr(client, "%s_response" % query)
|
||||
response_args = {}
|
||||
qargs = args.copy()
|
||||
|
||||
qfunc = getattr(client, "create_%s" % query)
|
||||
# remove args the create function can't handle
|
||||
fargs = inspect.getargspec(qfunc).args
|
||||
for arg in qargs.keys():
|
||||
if arg not in fargs:
|
||||
del qargs[arg]
|
||||
|
||||
resp = None
|
||||
for loc in locations:
|
||||
qargs["destination"] = loc
|
||||
|
||||
req = qfunc(**qargs)
|
||||
environ["request"] = req
|
||||
_req_str = "%s" % req
|
||||
trace.info("SAML Request: %s" % _req_str)
|
||||
# depending on binding send the query
|
||||
|
||||
if args["binding"] is BINDING_HTTP_REDIRECT:
|
||||
(head, _body) = http_redirect_message(_req_str, loc, relay_state)
|
||||
res = httpc.request(head[0][1], "GET")
|
||||
response_args["outstanding"] = {req.id: "/"}
|
||||
# head should contain a redirect
|
||||
# deal with redirect, should in the end give me a response
|
||||
try:
|
||||
response = cfunc(**req.args)
|
||||
response = intermit(client, res, httpc, environ, trace, cjar,
|
||||
interaction, test_output, features)
|
||||
except FatalError:
|
||||
environ["FatalError"] = True
|
||||
response = None
|
||||
|
||||
if isinstance(response, dict):
|
||||
assert relay_state == response["RelayState"]
|
||||
elif args["binding"] is BINDING_HTTP_POST:
|
||||
body = form_post(_req_str, relay_state)
|
||||
res = httpc.request(loc, "POST", data=body)
|
||||
response_args["outstanding"] = {req.id: "/"}
|
||||
# head should contain a redirect
|
||||
# deal with redirect, should in the end give me a response
|
||||
try:
|
||||
response = intermit(client, res, httpc, environ, trace, cjar,
|
||||
interaction, test_output, features)
|
||||
except FatalError:
|
||||
environ["FatalError"] = True
|
||||
response = None
|
||||
|
||||
elif args["binding"] is BINDING_SOAP:
|
||||
response = client.send_using_soap(_req_str, loc,
|
||||
client.config.key_file,
|
||||
client.config.cert_file,
|
||||
ca_certs=client.config.ca_certs)
|
||||
else:
|
||||
response = None
|
||||
|
||||
if response:
|
||||
try:
|
||||
_resp = _response_func(response, **response_args)
|
||||
environ["response"] = _resp
|
||||
trace.info("SAML Response: %s" % _resp)
|
||||
try:
|
||||
for test in req.tests["post"]:
|
||||
if isinstance(test, tuple):
|
||||
test, kwargs = test
|
||||
else:
|
||||
kwargs = {}
|
||||
chk = test(**kwargs)
|
||||
for test in oper.tests["post"]:
|
||||
chk = test()
|
||||
stat = chk(environ, test_output)
|
||||
check_severity(stat)
|
||||
if isinstance(chk, ExpectedError):
|
||||
item.append(stat["temp"])
|
||||
del stat["temp"]
|
||||
url = None
|
||||
break
|
||||
check_severity(stat, trace)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
except FatalError:
|
||||
raise
|
||||
except FatalError:
|
||||
environ["FatalError"] = True
|
||||
break
|
||||
except Exception, err:
|
||||
environ["exception"] = err
|
||||
chk = factory("exception")()
|
||||
chk(environ, test_output)
|
||||
raise FatalError()
|
||||
|
||||
if not response:
|
||||
continue
|
||||
|
||||
if response.status_code >= 400:
|
||||
done = True
|
||||
elif url:
|
||||
done = False
|
||||
else:
|
||||
done = True
|
||||
|
||||
while not done:
|
||||
while response.status_code in [302, 301, 303]:
|
||||
url = response.headers["location"]
|
||||
|
||||
trace.reply("REDIRECT TO: %s" % url)
|
||||
# If back to me
|
||||
for_me = False
|
||||
for redirect_uri in client.redirect_uris:
|
||||
if url.startswith(redirect_uri):
|
||||
# Back at the RP
|
||||
environ["client"].cookiejar = cjar["rp"]
|
||||
for_me=True
|
||||
|
||||
if for_me:
|
||||
done = True
|
||||
break
|
||||
else:
|
||||
try:
|
||||
part = do_request(client, url, "GET", trace=trace)
|
||||
except Exception, err:
|
||||
raise FatalError("%s" % err)
|
||||
environ.update(dict(zip(ORDER, part)))
|
||||
(url, response, content) = part
|
||||
|
||||
check = factory("check-http-response")()
|
||||
stat = check(environ, test_output)
|
||||
check_severity(stat)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
_base = url.split("?")[0]
|
||||
|
||||
try:
|
||||
_spec = pick_interaction(interaction, _base, content)
|
||||
except KeyError:
|
||||
if creq.method == "POST":
|
||||
break
|
||||
elif not req.request in ["AuthorizationRequest",
|
||||
"OpenIDRequest"]:
|
||||
break
|
||||
else:
|
||||
try:
|
||||
_check = getattr(req, "interaction_check")
|
||||
except AttributeError:
|
||||
_check = None
|
||||
|
||||
if _check:
|
||||
chk = factory("interaction-check")()
|
||||
chk(environ, test_output)
|
||||
raise FatalError()
|
||||
else:
|
||||
chk = factory("interaction-needed")()
|
||||
chk(environ, test_output)
|
||||
raise FatalError()
|
||||
|
||||
if len(_spec) > 2:
|
||||
trace.info(">> %s <<" % _spec["page-type"])
|
||||
if _spec["page-type"] == "login":
|
||||
environ["login"] = content
|
||||
|
||||
_op = Operation(_spec["control"])
|
||||
|
||||
try:
|
||||
part = _op(environ, trace, url, response, content, features)
|
||||
environ.update(dict(zip(ORDER, part)))
|
||||
(url, response, content) = part
|
||||
|
||||
check = factory("check-http-response")()
|
||||
stat = check(environ, test_output)
|
||||
check_severity(stat)
|
||||
except FatalError:
|
||||
raise
|
||||
except Exception, err:
|
||||
environ["exception"] = err
|
||||
chk = factory("exception")()
|
||||
chk(environ, test_output)
|
||||
raise FatalError
|
||||
|
||||
# if done:
|
||||
# break
|
||||
|
||||
info = None
|
||||
qresp = None
|
||||
resp_type = resp.type
|
||||
if response:
|
||||
try:
|
||||
ctype = response.headers["content-type"]
|
||||
if ctype == "application/jwt":
|
||||
resp_type = "jwt"
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
if response.status_code >= 400:
|
||||
pass
|
||||
elif not url:
|
||||
if isinstance(content, Message):
|
||||
qresp = content
|
||||
elif response.status_code == 200:
|
||||
info = content
|
||||
elif resp.where == "url" or response.status_code == 302:
|
||||
try:
|
||||
info = response.headers["location"]
|
||||
resp_type = "urlencoded"
|
||||
except KeyError:
|
||||
try:
|
||||
_check = getattr(req, "interaction_check", None)
|
||||
except AttributeError:
|
||||
_check = None
|
||||
|
||||
if _check:
|
||||
chk = factory("interaction-check")()
|
||||
chk(environ, test_output)
|
||||
raise FatalError()
|
||||
else:
|
||||
chk = factory("missing-redirect")()
|
||||
stat = chk(environ, test_output)
|
||||
check_severity(stat)
|
||||
else:
|
||||
check = factory("check_content_type_header")()
|
||||
stat = check(environ, test_output)
|
||||
check_severity(stat)
|
||||
info = content
|
||||
|
||||
if info and resp.response:
|
||||
if isinstance(resp.response, basestring):
|
||||
response = msgfactory(resp.response)
|
||||
else:
|
||||
response = resp.response
|
||||
|
||||
chk = factory("response-parse")()
|
||||
environ["response_type"] = response.__name__
|
||||
environ["responses"].append((response, info))
|
||||
try:
|
||||
qresp = client.parse_response(response, info, resp_type,
|
||||
client.state,
|
||||
keystore=_keystore,
|
||||
client_id=client.client_id,
|
||||
scope="openid")
|
||||
if trace and qresp:
|
||||
trace.info("[%s]: %s" % (qresp.type(),
|
||||
qresp.to_dict()))
|
||||
item.append(qresp)
|
||||
environ["response_message"] = qresp
|
||||
err = None
|
||||
except Exception, err:
|
||||
environ["exception"] = "%s" % err
|
||||
qresp = None
|
||||
if err and except_exception:
|
||||
if isinstance(err, except_exception):
|
||||
trace.info("Got expected exception: %s [%s]" % (err,
|
||||
err.__class__.__name__))
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
stat = chk(environ, test_output)
|
||||
check_severity(stat)
|
||||
|
||||
if qresp:
|
||||
try:
|
||||
for test in resp.tests["post"]:
|
||||
if isinstance(test, tuple):
|
||||
test, kwargs = test
|
||||
else:
|
||||
kwargs = {}
|
||||
chk = test(**kwargs)
|
||||
stat = chk(environ, test_output)
|
||||
check_severity(stat)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
resp(environ, qresp)
|
||||
|
||||
if tests is not None:
|
||||
environ["item"] = item
|
||||
for test, args in tests:
|
||||
if isinstance(test, basestring):
|
||||
chk = factory(test)(**args)
|
||||
else:
|
||||
chk = test(**args)
|
||||
try:
|
||||
check_severity(chk(environ, test_output))
|
||||
except Exception, err:
|
||||
raise FatalError("%s" % err)
|
||||
|
||||
except FatalError:
|
||||
pass
|
||||
except Exception, err:
|
||||
environ["exception"] = err
|
||||
chk = factory("exception")()
|
||||
chk(environ, test_output)
|
||||
|
||||
return test_output, "%s" % trace
|
||||
|
||||
|
||||
def run_sequences(client, sequences, trace, interaction,
|
||||
verbose=False):
|
||||
for sequence, endpoints, fid in sequences:
|
||||
# clear cookie cache
|
||||
client.grant.clear()
|
||||
try:
|
||||
client.http.cookiejar.clear()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
err = run_sequence(client, sequence, trace, interaction, verbose)
|
||||
|
||||
if err:
|
||||
print "%s - FAIL" % fid
|
||||
print
|
||||
if not verbose:
|
||||
print trace
|
||||
else:
|
||||
print "%s - OK" % fid
|
||||
|
||||
trace.clear()
|
||||
return test_output
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import inspect
|
||||
import sys
|
||||
import traceback
|
||||
from saml2.md import EntitiesDescriptor
|
||||
from saml2.saml import NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_PERSISTENT
|
||||
from saml2.saml import NAME_FORMAT_URI
|
||||
from saml2.sigver import cert_from_key_info
|
||||
from saml2.sigver import key_from_key_value
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
@@ -67,6 +73,18 @@ class Other(CriticalError):
|
||||
""" Other error """
|
||||
msg = "Other error"
|
||||
|
||||
class WrapException(CriticalError):
|
||||
"""
|
||||
A runtime exception
|
||||
"""
|
||||
id = "exception"
|
||||
msg = "Test tool exception"
|
||||
|
||||
def _func(self, environ=None):
|
||||
self._status = self.status
|
||||
self._message = traceback.format_exception(*sys.exc_info())
|
||||
return {}
|
||||
|
||||
class CheckHTTPResponse(CriticalError):
|
||||
"""
|
||||
Checks that the HTTP response status is within the 200 or 300 range
|
||||
@@ -76,36 +94,241 @@ class CheckHTTPResponse(CriticalError):
|
||||
|
||||
def _func(self, environ):
|
||||
_response = environ["response"]
|
||||
_content = environ["content"]
|
||||
|
||||
res = {}
|
||||
if _response.status_code >= 400 :
|
||||
self._status = self.status
|
||||
self._message = self.msg
|
||||
# if CONT_JSON in _response.headers["content-type"]:
|
||||
# try:
|
||||
# err = ErrorResponse().deserialize(_content, "json")
|
||||
# self._message = err.to_json()
|
||||
# except Exception:
|
||||
# res["content"] = _content
|
||||
# else:
|
||||
# res["content"] = _content
|
||||
res["url"] = environ["url"]
|
||||
res["http_status"] = _response.status_code
|
||||
else:
|
||||
# might still be an error message
|
||||
try:
|
||||
# err = ErrorResponse().deserialize(_content, "json")
|
||||
# err.verify()
|
||||
# self._message = err.to_json()
|
||||
self._status = self.status
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
res["url"] = environ["url"]
|
||||
res["content"] = _response.text
|
||||
|
||||
return res
|
||||
|
||||
class CheckSaml2IntMetaData(Check):
|
||||
"""
|
||||
Checks that the Metadata follows the profile
|
||||
"""
|
||||
id = "check-saml2int-metadata"
|
||||
msg = "Metadata error"
|
||||
|
||||
def verify_key_info(self, ki):
|
||||
# key_info
|
||||
# one or more key_value and/or x509_data.X509Certificate
|
||||
try:
|
||||
assert ki.key_value or ki.x509_data
|
||||
except AssertionError:
|
||||
self._message = "Missing KeyValue or X509Data.X509Certificate"
|
||||
self._status = CRITICAL
|
||||
return False
|
||||
|
||||
xkeys = cert_from_key_info(ki)
|
||||
vkeys = key_from_key_value(ki)
|
||||
|
||||
if xkeys and vkeys:
|
||||
# verify that it's the same keys
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
def verify_key_descriptor(self, kd):
|
||||
# key_info
|
||||
if not self.verify_key_info(kd.key_info):
|
||||
return False
|
||||
|
||||
# use
|
||||
if kd.use:
|
||||
try:
|
||||
assert kd.use in ["encryption", "signing"]
|
||||
except AssertionError:
|
||||
self._message = "Unknown use specification: '%s'" % kd.use.text
|
||||
self._status = CRITICAL
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _func(self, environ):
|
||||
if isinstance(environ["metadata"], EntitiesDescriptor):
|
||||
ed = environ["metadata"].entity_descriptor[0]
|
||||
else:
|
||||
ed = environ["metadata"]
|
||||
|
||||
res = {}
|
||||
|
||||
assert len(ed.idpsso_descriptor)
|
||||
idpsso = ed.idpsso_descriptor[0]
|
||||
for kd in idpsso.key_descriptor:
|
||||
if self.verify_key_descriptor(kd) == False:
|
||||
return res
|
||||
|
||||
# contact person
|
||||
if not idpsso.contact_person:
|
||||
self._message = "Metadata should contain contact person information"
|
||||
self._status = WARNING
|
||||
return res
|
||||
else:
|
||||
item = {"support": False, "technical": False}
|
||||
for contact in idpsso.contact_person:
|
||||
try:
|
||||
item[contact.contact_type] = True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if item["support"] and item["technical"]:
|
||||
pass
|
||||
elif not item["support"] and not item["technical"]:
|
||||
self._message = "Missing technical and support contact information"
|
||||
self._status = WARNING
|
||||
elif item["support"]:
|
||||
self._message = "Missing technical contact information"
|
||||
self._status = WARNING
|
||||
elif item["technical"]:
|
||||
self._message = "Missing support contact information"
|
||||
self._status = WARNING
|
||||
|
||||
if self._message:
|
||||
return res
|
||||
|
||||
# NameID format
|
||||
if not idpsso.nameid_format:
|
||||
self._message = "Metadata should specify NameID format support"
|
||||
self._status = WARNING
|
||||
return res
|
||||
else:
|
||||
# should support Transient
|
||||
item = {NAMEID_FORMAT_TRANSIENT:False}
|
||||
for format in idpsso.nameid_format:
|
||||
try:
|
||||
item[format] = True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if not item[NAMEID_FORMAT_TRANSIENT]:
|
||||
self._message = "IdP should support Transient NameID Format"
|
||||
self._status = WARNING
|
||||
return res
|
||||
|
||||
return res
|
||||
|
||||
class CheckSaml2IntAttributes(Check):
|
||||
"""
|
||||
Any <saml2:Attribute> elements exchanged via any SAML 2.0 messages,
|
||||
assertions, or metadata MUST contain a NameFormat of
|
||||
urn:oasis:names:tc:SAML:2.0:attrname-format:uri.
|
||||
"""
|
||||
id = "check-saml2int-attributes"
|
||||
msg = "Attribute error"
|
||||
|
||||
def _func(self, environ):
|
||||
response = environ["response"]
|
||||
try:
|
||||
opaque_identifier = environ["opaque_identifier"]
|
||||
except KeyError:
|
||||
opaque_identifier = False
|
||||
try:
|
||||
name_format_not_specified = environ["name_format_not_specified"]
|
||||
except KeyError:
|
||||
name_format_not_specified = False
|
||||
|
||||
res = {}
|
||||
|
||||
# should be a list but isn't
|
||||
#assert len(response.assertion) == 1
|
||||
assertion = response.assertion
|
||||
assert len(assertion.authn_statement) == 1
|
||||
assert len(assertion.attribute_statement) < 2
|
||||
|
||||
if assertion.attribute_statement:
|
||||
atrstat = assertion.attribute_statement[0]
|
||||
for attr in atrstat.attribute:
|
||||
try:
|
||||
assert attr.name_format == NAME_FORMAT_URI
|
||||
except AssertionError:
|
||||
self._message = "Attribute name format error"
|
||||
self._status = CRITICAL
|
||||
return res
|
||||
try:
|
||||
assert attr.name.startswith("urn:oid")
|
||||
except AssertionError:
|
||||
self._message = "Attribute name should be an OID"
|
||||
self._status = CRITICAL
|
||||
return res
|
||||
|
||||
assert not assertion.subject.encrypted_id
|
||||
assert not assertion.subject.base_id
|
||||
|
||||
if opaque_identifier:
|
||||
try:
|
||||
assert assertion.subject.name_id.format == NAMEID_FORMAT_PERSISTENT
|
||||
except AssertionError:
|
||||
self._message = "NameID format should be TRANSIENT"
|
||||
self._status = WARNING
|
||||
|
||||
if name_format_not_specified:
|
||||
try:
|
||||
assert assertion.subject.name_id.format == NAMEID_FORMAT_TRANSIENT
|
||||
except AssertionError:
|
||||
self._message = "NameID format should be TRANSIENT"
|
||||
self._status = WARNING
|
||||
|
||||
return res
|
||||
|
||||
class CheckSubjectNameIDFormat(Check):
|
||||
"""
|
||||
The <NameIDPolicy> element tailors the name identifier in the subjects of
|
||||
assertions resulting from an <AuthnRequest>.
|
||||
When this element is used, if the content is not understood by or acceptable
|
||||
to the identity provider, then a <Response> message element MUST be
|
||||
returned with an error <Status>, and MAY contain a second-level
|
||||
<StatusCode> of urn:oasis:names:tc:SAML:2.0:status:InvalidNameIDPolicy.
|
||||
If the Format value is omitted or set to urn:oasis:names:tc:SAML:2.0:nameid-
|
||||
format:unspecified, then the identity provider is free to return any kind
|
||||
of identifier, subject to any additional constraints due to the content of
|
||||
this element or the policies of the identity provider or principal.
|
||||
"""
|
||||
id = "check-saml2int-attributes"
|
||||
msg = "Attribute error"
|
||||
|
||||
def _func(self, environ):
|
||||
response = environ["response"]
|
||||
request = environ["request"]
|
||||
|
||||
res ={}
|
||||
if request.name_id_policy:
|
||||
format = request.name_id_policy.format
|
||||
sp_name_qualifier = request.name_id_policy.sp_name_qualifier
|
||||
|
||||
subj = response.assertion.subject
|
||||
try:
|
||||
assert subj.name_id.format == format
|
||||
if sp_name_qualifier:
|
||||
assert subj.name_id.sp_name_qualifier == sp_name_qualifier
|
||||
except AssertionError:
|
||||
self._message = "The IdP returns wrong NameID format"
|
||||
self._status = CRITICAL
|
||||
|
||||
return res
|
||||
|
||||
class CheckLogoutSupport(Check):
|
||||
id = "check-logout-support"
|
||||
msg = "Does not support logout"
|
||||
|
||||
def _func(self, environ):
|
||||
if isinstance(environ["metadata"], EntitiesDescriptor):
|
||||
ed = environ["metadata"].entity_descriptor[0]
|
||||
else:
|
||||
ed = environ["metadata"]
|
||||
|
||||
assert len(ed.idpsso_descriptor)
|
||||
idpsso = ed.idpsso_descriptor[0]
|
||||
try:
|
||||
assert idpsso.single_logout_service
|
||||
except AssertionError:
|
||||
self._message = self.msg
|
||||
self._status = CRITICAL
|
||||
|
||||
return {}
|
||||
|
||||
def factory(id):
|
||||
for name, obj in inspect.getmembers(sys.modules[__name__]):
|
||||
if inspect.isclass(obj):
|
||||
|
||||
137
src/idp_test/httpreq.py
Normal file
137
src/idp_test/httpreq.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from Cookie import SimpleCookie
|
||||
import cookielib
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
# =============================================================================
|
||||
|
||||
ATTRS = {"version":None,
|
||||
"name":"",
|
||||
"value": None,
|
||||
"port": None,
|
||||
"port_specified": False,
|
||||
"domain": "",
|
||||
"domain_specified": False,
|
||||
"domain_initial_dot": False,
|
||||
"path": "",
|
||||
"path_specified": False,
|
||||
"secure": False,
|
||||
"expires": None,
|
||||
"discard": True,
|
||||
"comment": None,
|
||||
"comment_url": None,
|
||||
"rest": "",
|
||||
"rfc2109": True}
|
||||
|
||||
PAIRS = {
|
||||
"port": "port_specified",
|
||||
"domain": "domain_specified",
|
||||
"path": "path_specified"
|
||||
}
|
||||
|
||||
def _since_epoch(cdate):
|
||||
# date format 'Wed, 06-Jun-2012 01:34:34 GMT'
|
||||
if len(cdate) <= 5:
|
||||
return 0
|
||||
|
||||
cdate = cdate[5:-4]
|
||||
try:
|
||||
t = time.strptime(cdate, "%d-%b-%Y %H:%M:%S")
|
||||
except ValueError:
|
||||
t = time.strptime(cdate, "%d-%b-%y %H:%M:%S")
|
||||
return int(time.mktime(t))
|
||||
|
||||
class HTTPC(object):
|
||||
def __init__(self, ca_certs=None):
|
||||
|
||||
self.request_args = {"allow_redirects": False,}
|
||||
#self.cookies = cookielib.CookieJar()
|
||||
self.cookies = {}
|
||||
self.cookiejar = cookielib.CookieJar()
|
||||
|
||||
if ca_certs:
|
||||
self.request_args["verify"] = True
|
||||
else:
|
||||
self.request_args["verify"] = False
|
||||
|
||||
def _cookies(self):
|
||||
cookie_dict = {}
|
||||
|
||||
for _, a in list(self.cookiejar._cookies.items()):
|
||||
for _, b in list(a.items()):
|
||||
for cookie in list(b.values()):
|
||||
# print cookie
|
||||
cookie_dict[cookie.name] = cookie.value
|
||||
|
||||
return cookie_dict
|
||||
|
||||
def set_cookie(self, kaka, request):
|
||||
"""sets a cookie in a Cookie jar based on a set-cookie header line"""
|
||||
|
||||
# default rfc2109=False
|
||||
# max-age, httponly
|
||||
for cookie_name, morsel in kaka.items():
|
||||
std_attr = ATTRS.copy()
|
||||
std_attr["name"] = cookie_name
|
||||
_tmp = morsel.coded_value
|
||||
if _tmp.startswith('"') and _tmp.endswith('"'):
|
||||
std_attr["value"] = _tmp[1:-1]
|
||||
else:
|
||||
std_attr["value"] = _tmp
|
||||
|
||||
std_attr["version"] = 0
|
||||
# copy attributes that have values
|
||||
for attr in morsel.keys():
|
||||
if attr in ATTRS:
|
||||
if morsel[attr]:
|
||||
if attr == "expires":
|
||||
std_attr[attr]=_since_epoch(morsel[attr])
|
||||
else:
|
||||
std_attr[attr]=morsel[attr]
|
||||
elif attr == "max-age":
|
||||
if morsel["max-age"]:
|
||||
std_attr["expires"] = _since_epoch(morsel["max-age"])
|
||||
|
||||
for att, set in PAIRS.items():
|
||||
if std_attr[att]:
|
||||
std_attr[set] = True
|
||||
|
||||
if std_attr["domain"] and std_attr["domain"].startswith("."):
|
||||
std_attr["domain_initial_dot"] = True
|
||||
|
||||
if morsel["max-age"] is 0:
|
||||
try:
|
||||
self.cookiejar.clear(domain=std_attr["domain"],
|
||||
path=std_attr["path"],
|
||||
name=std_attr["name"])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
new_cookie = cookielib.Cookie(**std_attr)
|
||||
|
||||
self.cookiejar.set_cookie(new_cookie)
|
||||
|
||||
def request(self, url, method="GET", trace=None, **kwargs):
|
||||
_kwargs = copy.copy(self.request_args)
|
||||
if kwargs:
|
||||
_kwargs.update(kwargs)
|
||||
|
||||
if self.cookiejar:
|
||||
_kwargs["cookies"] = self._cookies()
|
||||
if trace:
|
||||
trace.info("SENT COOKIEs: %s" % (_kwargs["cookies"],))
|
||||
r = requests.request(method, url, **_kwargs)
|
||||
try:
|
||||
if trace:
|
||||
trace.info("RECEIVED COOKIEs: %s" % (r.headers["set-cookie"],))
|
||||
self.set_cookie(SimpleCookie(r.headers["set-cookie"]), r)
|
||||
except AttributeError, err:
|
||||
pass
|
||||
|
||||
return r
|
||||
379
src/idp_test/interaction.py
Normal file
379
src/idp_test/interaction.py
Normal file
@@ -0,0 +1,379 @@
|
||||
__author__ = 'rohe0002'
|
||||
|
||||
import json
|
||||
|
||||
from urlparse import urlparse
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from mechanize import ParseResponseEx
|
||||
from mechanize._form import ControlNotFoundError, AmbiguityError
|
||||
from mechanize._form import ListControl
|
||||
|
||||
def pick_interaction(interactions, _base="", content="", req=None):
|
||||
unic = content
|
||||
if content:
|
||||
_bs = BeautifulSoup(content)
|
||||
else:
|
||||
_bs = None
|
||||
|
||||
for interaction in interactions:
|
||||
_match = 0
|
||||
for attr, val in interaction["matches"].items():
|
||||
if attr == "url":
|
||||
if val == _base:
|
||||
_match += 1
|
||||
elif attr == "title":
|
||||
if _bs is None:
|
||||
break
|
||||
if _bs.title is None:
|
||||
break
|
||||
if val in _bs.title.contents:
|
||||
_match += 1
|
||||
elif attr == "content":
|
||||
if unic and val in unic:
|
||||
_match += 1
|
||||
elif attr == "class":
|
||||
if req and val == req:
|
||||
_match += 1
|
||||
|
||||
if _match == len(interaction["matches"]):
|
||||
return interaction
|
||||
|
||||
raise KeyError("No interaction matched")
|
||||
|
||||
class FlowException(Exception):
|
||||
def __init__(self, function="", content="", url=""):
|
||||
Exception.__init__(self)
|
||||
self.function = function
|
||||
self.content = content
|
||||
self.url = url
|
||||
|
||||
def __str__(self):
|
||||
return json.dumps(self.__dict__)
|
||||
|
||||
class RResponse():
|
||||
"""
|
||||
A Response class that behaves in the way that mechanize expects it.
|
||||
Links to a requests.Response
|
||||
"""
|
||||
def __init__(self, resp):
|
||||
self._resp = resp
|
||||
self.index = 0
|
||||
self.text = resp.text
|
||||
if isinstance(self.text, unicode):
|
||||
if resp.encoding == "UTF-8":
|
||||
self.text = self.text.encode("utf-8")
|
||||
else:
|
||||
self.text = self.text.encode("latin-1")
|
||||
self._len = len(self.text)
|
||||
self.url = str(resp.url)
|
||||
self.statuscode = resp.status_code
|
||||
|
||||
def geturl(self):
|
||||
return self._resp.url
|
||||
|
||||
def __getitem__(self, item):
|
||||
try:
|
||||
return getattr(self._resp, item)
|
||||
except AttributeError:
|
||||
return getattr(self._resp.headers, item)
|
||||
|
||||
def __getattribute__(self, item):
|
||||
try:
|
||||
return getattr(self._resp, item)
|
||||
except AttributeError:
|
||||
return getattr(self._resp.headers, item)
|
||||
|
||||
def read(self, size=0):
|
||||
"""
|
||||
Read from the content of the response. The class remembers what has
|
||||
been read so it's possible to read small consecutive parts of the
|
||||
content.
|
||||
|
||||
:param size: The number of bytes to read
|
||||
:return: Somewhere between zero and 'size' number of bytes depending
|
||||
on how much it left in the content buffer to read.
|
||||
"""
|
||||
if size:
|
||||
if self._len < size:
|
||||
return self.text
|
||||
else:
|
||||
if self._len == self.index:
|
||||
part = None
|
||||
elif self._len - self.index < size:
|
||||
part = self.text[self.index:]
|
||||
self.index = self._len
|
||||
else:
|
||||
part = self.text[self.index:self.index+size]
|
||||
self.index += size
|
||||
return part
|
||||
else:
|
||||
return self.text
|
||||
|
||||
|
||||
def pick_form(response, url=None, **kwargs):
|
||||
"""
|
||||
Picks which form in a web-page that should be used
|
||||
|
||||
:param response: A HTTP request response. A DResponse instance
|
||||
:param content: The HTTP response content
|
||||
:param url: The url the request was sent to
|
||||
:return: The picked form or None of no form matched the criteria.
|
||||
"""
|
||||
_txt = response.text
|
||||
|
||||
forms = ParseResponseEx(response)
|
||||
if not forms:
|
||||
raise FlowException(content=response.text, url=url)
|
||||
|
||||
#if len(forms) == 1:
|
||||
# return forms[0]
|
||||
#else:
|
||||
|
||||
_form = None
|
||||
# ignore the first form, because I use ParseResponseEx which adds
|
||||
# one form at the top of the list
|
||||
forms = forms[1:]
|
||||
if len(forms) == 1:
|
||||
_form = forms[0]
|
||||
else:
|
||||
if "pick" in kwargs:
|
||||
_dict = kwargs["pick"]
|
||||
for form in forms:
|
||||
if _form:
|
||||
break
|
||||
for key, _ava in _dict.items():
|
||||
if key == "form":
|
||||
_keys = form.attrs.keys()
|
||||
for attr, val in _ava.items():
|
||||
if attr in _keys and val == form.attrs[attr]:
|
||||
_form = form
|
||||
elif key == "control":
|
||||
prop = _ava["id"]
|
||||
_default = _ava["value"]
|
||||
try:
|
||||
orig_val = form[prop]
|
||||
if isinstance(orig_val, basestring):
|
||||
if orig_val == _default:
|
||||
_form = form
|
||||
elif _default in orig_val:
|
||||
_form = form
|
||||
except KeyError:
|
||||
pass
|
||||
except ControlNotFoundError:
|
||||
pass
|
||||
elif key == "method":
|
||||
if form.method == _ava:
|
||||
_form = form
|
||||
else:
|
||||
_form = None
|
||||
|
||||
if not _form:
|
||||
break
|
||||
elif "index" in kwargs:
|
||||
_form = forms[int(kwargs["index"])]
|
||||
|
||||
return _form
|
||||
|
||||
def do_click(httpc, form, **kwargs):
|
||||
"""
|
||||
Emulates the user clicking submit on a form.
|
||||
|
||||
:param httpc: The Client instance
|
||||
:param form: The form that should be submitted
|
||||
:return: What do_request() returns
|
||||
"""
|
||||
|
||||
if "click" in kwargs:
|
||||
request=None
|
||||
_name = kwargs["click"]
|
||||
try:
|
||||
_ = form.find_control(name=_name)
|
||||
request = form.click(name=_name)
|
||||
except AmbiguityError:
|
||||
# more than one control with that name
|
||||
_val = kwargs["set"][_name]
|
||||
_nr = 0
|
||||
while True:
|
||||
try:
|
||||
cntrl = form.find_control(name=_name, nr=_nr)
|
||||
if cntrl.value == _val:
|
||||
request = form.click(name=_name, nr=_nr)
|
||||
break
|
||||
else:
|
||||
_nr += 1
|
||||
except ControlNotFoundError:
|
||||
raise Exception("No submit control with the name='%s' and "
|
||||
"value='%s' could be found" % (_name,
|
||||
_val))
|
||||
else:
|
||||
request = form.click()
|
||||
|
||||
headers = {}
|
||||
for key, val in request.unredirected_hdrs.items():
|
||||
headers[key] = val
|
||||
|
||||
url = request._Request__original
|
||||
try:
|
||||
_trace = kwargs["_trace_"]
|
||||
except KeyError:
|
||||
_trace = False
|
||||
|
||||
if form.method == "POST":
|
||||
return httpc.request(url, "POST", data=request.data, headers=headers,
|
||||
trace=_trace)
|
||||
else:
|
||||
return httpc.request(url, "GET", headers=headers, trace=_trace)
|
||||
|
||||
def select_form(httpc, orig_response, **kwargs):
|
||||
"""
|
||||
Pick a form on a web page, possibly enter some information and submit
|
||||
the form.
|
||||
|
||||
:param httpc: A HTTP client instance
|
||||
:param orig_response: The original response (as returned by requests)
|
||||
:return: The response do_click() returns
|
||||
"""
|
||||
response = RResponse(orig_response)
|
||||
try:
|
||||
_url = response.url
|
||||
except KeyError:
|
||||
_url = kwargs["location"]
|
||||
|
||||
form = pick_form(response, _url, **kwargs)
|
||||
#form.backwards_compatible = False
|
||||
if not form:
|
||||
raise Exception("Can't pick a form !!")
|
||||
|
||||
if "set" in kwargs:
|
||||
for key, val in kwargs["set"].items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if "click" in kwargs and kwargs["click"] == key:
|
||||
continue
|
||||
|
||||
try:
|
||||
form[key] = val
|
||||
except ControlNotFoundError:
|
||||
pass
|
||||
except TypeError:
|
||||
cntrl = form.find_control(key)
|
||||
if isinstance(cntrl, ListControl):
|
||||
form[key] = [val]
|
||||
else:
|
||||
raise
|
||||
|
||||
return do_click(httpc, form, **kwargs)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def chose(httpc, orig_response, path, **kwargs):
|
||||
"""
|
||||
Sends a HTTP GET to a url given by the present url and the given
|
||||
relative path.
|
||||
|
||||
:param orig_response: The original response
|
||||
:param content: The content of the response
|
||||
:param path: The relative path to add to the base URL
|
||||
:return: The response do_click() returns
|
||||
"""
|
||||
|
||||
try:
|
||||
_trace = kwargs["trace"]
|
||||
except KeyError:
|
||||
_trace = False
|
||||
|
||||
if not path.startswith("http"):
|
||||
try:
|
||||
_url = orig_response.url
|
||||
except KeyError:
|
||||
_url = kwargs["location"]
|
||||
|
||||
part = urlparse(_url)
|
||||
url = "%s://%s%s" % (part[0], part[1], path)
|
||||
else:
|
||||
url = path
|
||||
|
||||
return httpc.request(url, "GET", trace=_trace)
|
||||
#return resp, ""
|
||||
|
||||
def post_form(httpc, orig_response, **kwargs):
|
||||
"""
|
||||
The same as select_form but with no possibility of change the content
|
||||
of the form.
|
||||
|
||||
:param httpc: A HTTP Client instance
|
||||
:param orig_response: The original response (as returned by requests)
|
||||
:param content: The content of the response
|
||||
:return: The response do_click() returns
|
||||
"""
|
||||
response = RResponse(orig_response)
|
||||
|
||||
form = pick_form(response, **kwargs)
|
||||
|
||||
return do_click(httpc, form, **kwargs)
|
||||
|
||||
def NoneFunc():
|
||||
return None
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def parse(httpc, orig_response, **kwargs):
|
||||
# content is a form from which I get the SAMLResponse
|
||||
response = RResponse(orig_response)
|
||||
|
||||
form = pick_form(response, **kwargs)
|
||||
#form.backwards_compatible = False
|
||||
if not form:
|
||||
raise Exception("Can't pick a form !!")
|
||||
|
||||
return {"SAMLResponse": form["SAMLResponse"],
|
||||
"RelayState":form["RelayState"]}
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def interaction(args):
|
||||
_type = args["type"]
|
||||
if _type == "form":
|
||||
return select_form
|
||||
elif _type == "link":
|
||||
return chose
|
||||
elif _type == "response":
|
||||
return parse
|
||||
else:
|
||||
return NoneFunc
|
||||
|
||||
# ========================================================================
|
||||
|
||||
class Operation(object):
|
||||
def __init__(self, args=None):
|
||||
if args:
|
||||
self.function = interaction(args)
|
||||
|
||||
self.args = args or {}
|
||||
self.request = None
|
||||
|
||||
def update(self, dic):
|
||||
self.args.update(dic)
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def post_op(self, result, environ, args):
|
||||
pass
|
||||
|
||||
def __call__(self, httpc, environ, trace, location, response, content,
|
||||
features):
|
||||
try:
|
||||
_args = self.args.copy()
|
||||
except (KeyError, AttributeError):
|
||||
_args = {}
|
||||
|
||||
_args["_trace_"] = trace
|
||||
_args["location"] = location
|
||||
_args["features"] = features
|
||||
|
||||
if trace:
|
||||
trace.reply("FUNCTION: %s" % self.function.__name__)
|
||||
trace.reply("ARGS: %s" % _args)
|
||||
|
||||
result = self.function(httpc, response, **_args)
|
||||
self.post_op(result, environ, _args)
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
from check import CheckHTTPResponse
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
class Request():
|
||||
request = ""
|
||||
method = ""
|
||||
lax = False
|
||||
_request_args= {}
|
||||
kw_args = {}
|
||||
tests = {"post": [CheckHTTPResponse], "pre":[]}
|
||||
|
||||
def __init__(self, cconf=None):
|
||||
self.cconf = cconf
|
||||
self.request_args = self._request_args.copy()
|
||||
|
||||
#noinspection PyUnusedLocal
|
||||
def __call__(self, environ, trace, location, response, content, features):
|
||||
_client = environ["client"]
|
||||
try:
|
||||
kwargs = self.kw_args.copy()
|
||||
except KeyError:
|
||||
kwargs = {}
|
||||
|
||||
func = getattr(_client, "do_%s" % self.request)
|
||||
|
||||
ht_add = None
|
||||
|
||||
if "authn_method" in kwargs:
|
||||
h_arg = _client.init_authentication_method(cis, **kwargs)
|
||||
else:
|
||||
h_arg = None
|
||||
|
||||
url, body, ht_args, cis = _client.uri_and_body(request, cis,
|
||||
method=self.method,
|
||||
request_args=_req)
|
||||
|
||||
environ["cis"].append(cis)
|
||||
if h_arg:
|
||||
ht_args.update(h_arg)
|
||||
if ht_add:
|
||||
ht_args.update({"headers": ht_add})
|
||||
|
||||
if trace:
|
||||
try:
|
||||
oro = unpack(cis["request"])[1]
|
||||
trace.request("OpenID Request Object: %s" % oro)
|
||||
except KeyError:
|
||||
pass
|
||||
trace.request("URL: %s" % url)
|
||||
trace.request("BODY: %s" % body)
|
||||
try:
|
||||
trace.request("HEADERS: %s" % ht_args["headers"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response = _client.http_request(url, method=self.method, data=body,
|
||||
**ht_args)
|
||||
|
||||
if trace:
|
||||
trace.reply("RESPONSE: %s" % response)
|
||||
trace.reply("CONTENT: %s" % response.text)
|
||||
if response.status_code in [301, 302]:
|
||||
trace.reply("LOCATION: %s" % response.headers["location"])
|
||||
trace.reply("COOKIES: %s" % response.cookies)
|
||||
# try:
|
||||
# trace.reply("HeaderCookies: %s" % response.headers["set-cookie"])
|
||||
# except KeyError:
|
||||
# pass
|
||||
|
||||
return url, response, response.text
|
||||
|
||||
def update(self, dic):
|
||||
_tmp = {"request": self.request_args.copy(), "kw": self.kw_args}
|
||||
for key, val in self.rec_update(_tmp, dic).items():
|
||||
setattr(self, "%s_args" % key, val)
|
||||
|
||||
def rec_update(self, dic0, dic1):
|
||||
res = {}
|
||||
for key, val in dic0.items():
|
||||
if key not in dic1:
|
||||
res[key] = val
|
||||
else:
|
||||
if isinstance(val, dict):
|
||||
res[key] = self.rec_update(val, dic1[key])
|
||||
else:
|
||||
res[key] = dic1[key]
|
||||
|
||||
for key, val in dic1.items():
|
||||
if key in dic0:
|
||||
continue
|
||||
else:
|
||||
res[key] = val
|
||||
|
||||
return res
|
||||
71
src/idp_test/saml2int.py
Normal file
71
src/idp_test/saml2int.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from saml2 import BINDING_HTTP_REDIRECT
|
||||
from saml2 import BINDING_HTTP_POST
|
||||
from saml2.saml import NAMEID_FORMAT_PERSISTENT
|
||||
from idp_test.check import CheckSaml2IntMetaData
|
||||
from idp_test.check import CheckSaml2IntAttributes
|
||||
from idp_test.check import CheckSubjectNameIDFormat
|
||||
from idp_test.check import CheckLogoutSupport
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
class Request(object):
|
||||
_args = {}
|
||||
|
||||
def __init__(self):
|
||||
self.args = self._args.copy()
|
||||
|
||||
def setup(self, environ):
|
||||
pass
|
||||
|
||||
class AuthnRequest(Request):
|
||||
request = "authn_request"
|
||||
_args = {"binding": BINDING_HTTP_REDIRECT,
|
||||
"nameid_format": NAMEID_FORMAT_PERSISTENT,
|
||||
"allow_create": True}
|
||||
tests = {"pre": [CheckSaml2IntMetaData],
|
||||
"post": [CheckSaml2IntAttributes,
|
||||
# CheckSubjectNameIDFormat
|
||||
]}
|
||||
|
||||
class AuthnRequestPost(AuthnRequest):
|
||||
def __init__(self):
|
||||
AuthnRequest.__init__(self)
|
||||
self.args["binding"] = BINDING_HTTP_POST
|
||||
|
||||
|
||||
|
||||
class LogOutRequest(Request):
|
||||
request = "logout_request"
|
||||
tests = {"pre": [CheckLogoutSupport]}
|
||||
_args = {"binding": BINDING_HTTP_REDIRECT,
|
||||
# "sign": True
|
||||
}
|
||||
|
||||
def setup(self, environ):
|
||||
resp = environ["response"]
|
||||
subj = resp.assertion.subject
|
||||
self.args["subject_id"] = subj.name_id.text
|
||||
#self.args["name_id"] = subj.name_id
|
||||
self.args["issuer_entity_id"] = resp.assertion.issuer.text
|
||||
|
||||
OPERATIONS = {
|
||||
'basic-authn': {
|
||||
"name": 'Absolute basic SAML2 AuthnRequest',
|
||||
"descr": ('AuthnRequest using HTTP-redirect'),
|
||||
"sequence": [AuthnRequest],
|
||||
#"endpoints": ["authorization_endpoint"],
|
||||
#"block": ["key_export"]
|
||||
},
|
||||
'basic-authn-post': {
|
||||
"name": 'Basic SAML2 AuthnRequest using HTTP POST',
|
||||
"descr": ('AuthnRequest using HTTP-POST'),
|
||||
"sequence": [AuthnRequestPost],
|
||||
#"endpoints": ["authorization_endpoint"],
|
||||
#"block": ["key_export"]
|
||||
},
|
||||
'log-in-out': {
|
||||
"name": 'Absolute basic SAML2 AuthnRequest',
|
||||
"descr": ('AuthnRequest using HTTP-redirect'),
|
||||
"sequence": [AuthnRequest, LogOutRequest],
|
||||
}
|
||||
}
|
||||
199
src/idp_test/test.py
Normal file
199
src/idp_test/test.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import cookielib
|
||||
import inspect
|
||||
|
||||
from saml2 import BINDING_HTTP_REDIRECT
|
||||
from saml2 import BINDING_HTTP_POST
|
||||
from saml2 import BINDING_SOAP
|
||||
from saml2.binding import http_redirect_message
|
||||
from saml2.binding import http_post_message
|
||||
from saml2.binding import send_using_soap
|
||||
from saml2.client import Saml2Client
|
||||
from saml2.config import SPConfig
|
||||
from saml2.s_utils import rndstr
|
||||
from saml2.saml import NAMEID_FORMAT_PERSISTENT
|
||||
from saml2.metadata import REQ2SRV
|
||||
import time
|
||||
|
||||
from idp_test.interaction import Operation
|
||||
from idp_test.interaction import pick_interaction
|
||||
from idp_test.check import factory
|
||||
|
||||
from idp_test import SAML2
|
||||
|
||||
__author__ = 'rolandh'
|
||||
|
||||
class FatalError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
ORDER = ["url", "response", "content"]
|
||||
|
||||
def check_severity(stat):
|
||||
if stat["status"] >= 4:
|
||||
raise FatalError
|
||||
|
||||
def intermit(client, response, httpc, environ, trace, cjar, interaction,
|
||||
test_output, features=None):
|
||||
if response.status_code >= 400:
|
||||
done = True
|
||||
else:
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
while response.status_code in [302, 301, 303]:
|
||||
url = response.headers["location"]
|
||||
|
||||
trace.reply("REDIRECT TO: %s" % url)
|
||||
# If back to me
|
||||
for_me = False
|
||||
acs = client.config.getattr("endpoints",
|
||||
"sp")["assertion_consumer_service"]
|
||||
for redirect_uri in acs:
|
||||
if url.startswith(redirect_uri):
|
||||
# Back at the RP
|
||||
environ["client"].cookiejar = cjar["rp"]
|
||||
for_me=True
|
||||
|
||||
if for_me:
|
||||
done = True
|
||||
break
|
||||
else:
|
||||
try:
|
||||
response = httpc.request(url, "GET", trace=trace)
|
||||
except Exception, err:
|
||||
raise FatalError("%s" % err)
|
||||
|
||||
content = response.text
|
||||
environ.update({"url": url, "response": response,
|
||||
"content":content})
|
||||
|
||||
check = factory("check-http-response")()
|
||||
stat = check(environ, test_output)
|
||||
check_severity(stat)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
_base = url.split("?")[0]
|
||||
|
||||
try:
|
||||
_spec = pick_interaction(interaction, _base, content)
|
||||
except KeyError:
|
||||
chk = factory("interaction-needed")()
|
||||
chk(environ, test_output)
|
||||
raise FatalError()
|
||||
|
||||
if len(_spec) > 2:
|
||||
trace.info(">> %s <<" % _spec["page-type"])
|
||||
if _spec["page-type"] == "login":
|
||||
environ["login"] = content
|
||||
|
||||
_op = Operation(_spec["control"])
|
||||
|
||||
try:
|
||||
response = _op(httpc, environ, trace, url, response, content,
|
||||
features)
|
||||
if isinstance(response, dict):
|
||||
return response
|
||||
content = response.text
|
||||
environ.update({"url": url, "response": response,
|
||||
"content":content})
|
||||
|
||||
check = factory("check-http-response")()
|
||||
stat = check(environ, test_output)
|
||||
check_severity(stat)
|
||||
except FatalError:
|
||||
raise
|
||||
except Exception, err:
|
||||
environ["exception"] = err
|
||||
chk = factory("exception")()
|
||||
chk(environ, test_output)
|
||||
|
||||
return response
|
||||
|
||||
def do_query(config, oper, httpc, trace, interaction):
|
||||
environ = {}
|
||||
test_output = []
|
||||
client = Saml2Client(config)
|
||||
query = oper.request
|
||||
args = oper.args
|
||||
|
||||
cjar = {"browser": cookielib.CookieJar(),
|
||||
"rp": cookielib.CookieJar(),
|
||||
"service": cookielib.CookieJar()}
|
||||
|
||||
httpc.cookiejar = cjar["browser"]
|
||||
|
||||
locations = getattr(client.metadata, REQ2SRV[query])(args["entity_id"],
|
||||
args["binding"])
|
||||
|
||||
relay_state = rndstr()
|
||||
_response_func = getattr(client, "%s_response" % query)
|
||||
response_args = {}
|
||||
qargs = args.copy()
|
||||
|
||||
qfunc = getattr(client, "create_%s" % query)
|
||||
# remove args the create function can't handle
|
||||
fargs = inspect.getargspec(qfunc).args
|
||||
for arg in qargs.keys():
|
||||
if arg not in fargs:
|
||||
del qargs[arg]
|
||||
|
||||
resp = None
|
||||
for loc in locations:
|
||||
qargs["destination"] = loc
|
||||
|
||||
req = qfunc(**qargs)
|
||||
_req_str = "%s" % req
|
||||
# depending on binding send the query
|
||||
|
||||
if args["binding"] is BINDING_HTTP_REDIRECT:
|
||||
(head, _body) = http_redirect_message(_req_str, loc, relay_state)
|
||||
res = httpc.request(head[0][1], "GET")
|
||||
response_args["outstanding"] = {req.id: "/"}
|
||||
# head should contain a redirect
|
||||
# deal with redirect, should in the end give me a response
|
||||
response = intermit(client, res, httpc, environ, trace, cjar,
|
||||
interaction, test_output, features)
|
||||
if isinstance(response, dict):
|
||||
assert relay_state == response["RelayState"]
|
||||
elif args["binding"] is BINDING_HTTP_POST:
|
||||
(head, response) = http_post_message(_req_str, loc, relay_state)
|
||||
elif args["binding"] is BINDING_SOAP:
|
||||
response = send_using_soap(_req_str, loc, client.config.key_file,
|
||||
client.config.cert_file,
|
||||
ca_certs=client.config.ca_certs)
|
||||
else:
|
||||
response = None
|
||||
|
||||
if response:
|
||||
try:
|
||||
_ = _response_func(response, **response_args)
|
||||
break
|
||||
except Exception, err:
|
||||
environ["exception"] = err
|
||||
chk = factory("exception")()
|
||||
chk(environ, test_output)
|
||||
|
||||
return test_output, "%s" % trace
|
||||
|
||||
# ========================================================================
|
||||
|
||||
class Request(object):
|
||||
_args = {}
|
||||
|
||||
def __init__(self):
|
||||
self.args = self._args.copy()
|
||||
|
||||
class AuthnRequest(Request):
|
||||
request = "authn_request"
|
||||
_args = {"binding": BINDING_HTTP_REDIRECT,
|
||||
"nameid_format": NAMEID_FORMAT_PERSISTENT}
|
||||
|
||||
# ========================================================================
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
s = SAML2()
|
||||
s.run()
|
||||
Reference in New Issue
Block a user