Allow for specifying the test definition to not be in saml2base

This commit is contained in:
Roland Hedberg
2013-05-30 20:30:57 +02:00
parent dd1c8abbdb
commit 6ccfb9517d
2 changed files with 51 additions and 4 deletions

View File

@@ -1,9 +1,9 @@
#!/usr/bin/env python
__author__ = 'rohe0002'
from idp_test import saml2base
#from idp_test import saml2base
from idp_test import SAML2client
from idp_test.check import factory
cli = SAML2client(saml2base, factory)
cli = SAML2client(factory)
cli.run()

View File

@@ -1,9 +1,12 @@
from importlib import import_module
import json
import os
import types
import argparse
import sys
import logging
import imp
from saml2.client import Saml2Client
from saml2.config import SPConfig
@@ -42,11 +45,45 @@ streamhandler = logging.StreamHandler(sys.stderr)
memoryhandler = logging.handlers.MemoryHandler(1024*10, logging.DEBUG)
def recursive_find_module(name, path=None):
parts = name.split(".")
mod_a = None
for part in parts:
try:
(fil, pathname, desc) = imp.find_module(part, path)
except ImportError:
raise
mod_a = imp.load_module(name, fil, pathname, desc)
sys.modules[name] = mod_a
path = mod_a.__path__
return mod_a
def get_mod(name, path=None):
try:
mod_a = sys.modules[name]
if not isinstance(mod_a, types.ModuleType):
raise KeyError
except KeyError:
try:
(fil, pathname, desc) = imp.find_module(name, path)
mod_a = imp.load_module(name, fil, pathname, desc)
except ImportError:
if "." in name:
mod_a = recursive_find_module(name, path)
else:
raise
sys.modules[name] = mod_a
return mod_a
class SAML2client(object):
def __init__(self, operations, check_factory):
def __init__(self, check_factory):
self.trace = Trace()
self.operations = operations
self.tests = None
self.check_factory = check_factory
@@ -76,12 +113,15 @@ class SAML2client(object):
help="Path to the configuration file for the SP")
self._parser.add_argument("-t", dest="testpackage",
help="Module describing tests")
self._parser.add_argument("-O", dest="operations",
help="Tests")
self._parser.add_argument("oper", nargs="?", help="Which test to run")
self.interactions = None
self.entity_id = None
self.sp_config = None
self.constraints = {}
self.operations = None
def json_config_file(self):
if self.args.json_config_file == "-":
@@ -165,6 +205,13 @@ class SAML2client(object):
HANDLER = ""
self.args = self._parser.parse_args()
if self.args.operations:
path, name = os.path.split(self.args.operations)
self.operations = get_mod(name, [path])
else:
self.operations = __import__("idp_test.saml2base",
fromlist=["idp_test"])
if self.args.metadata:
return self.make_meta()
elif self.args.list: