Allow for specifying the test definition to not be in saml2base

This commit is contained in:
Roland Hedberg
2013-05-30 20:30:57 +02:00
parent dd1c8abbdb
commit 6ccfb9517d
2 changed files with 51 additions and 4 deletions

View File

@@ -1,9 +1,9 @@
#!/usr/bin/env python #!/usr/bin/env python
__author__ = 'rohe0002' __author__ = 'rohe0002'
from idp_test import saml2base #from idp_test import saml2base
from idp_test import SAML2client from idp_test import SAML2client
from idp_test.check import factory from idp_test.check import factory
cli = SAML2client(saml2base, factory) cli = SAML2client(factory)
cli.run() cli.run()

View File

@@ -1,9 +1,12 @@
from importlib import import_module from importlib import import_module
import json import json
import os
import types
import argparse import argparse
import sys import sys
import logging import logging
import imp
from saml2.client import Saml2Client from saml2.client import Saml2Client
from saml2.config import SPConfig from saml2.config import SPConfig
@@ -42,11 +45,45 @@ streamhandler = logging.StreamHandler(sys.stderr)
memoryhandler = logging.handlers.MemoryHandler(1024*10, logging.DEBUG) memoryhandler = logging.handlers.MemoryHandler(1024*10, logging.DEBUG)
def recursive_find_module(name, path=None):
parts = name.split(".")
mod_a = None
for part in parts:
try:
(fil, pathname, desc) = imp.find_module(part, path)
except ImportError:
raise
mod_a = imp.load_module(name, fil, pathname, desc)
sys.modules[name] = mod_a
path = mod_a.__path__
return mod_a
def get_mod(name, path=None):
try:
mod_a = sys.modules[name]
if not isinstance(mod_a, types.ModuleType):
raise KeyError
except KeyError:
try:
(fil, pathname, desc) = imp.find_module(name, path)
mod_a = imp.load_module(name, fil, pathname, desc)
except ImportError:
if "." in name:
mod_a = recursive_find_module(name, path)
else:
raise
sys.modules[name] = mod_a
return mod_a
class SAML2client(object): class SAML2client(object):
def __init__(self, operations, check_factory): def __init__(self, check_factory):
self.trace = Trace() self.trace = Trace()
self.operations = operations
self.tests = None self.tests = None
self.check_factory = check_factory self.check_factory = check_factory
@@ -76,12 +113,15 @@ class SAML2client(object):
help="Path to the configuration file for the SP") help="Path to the configuration file for the SP")
self._parser.add_argument("-t", dest="testpackage", self._parser.add_argument("-t", dest="testpackage",
help="Module describing tests") help="Module describing tests")
self._parser.add_argument("-O", dest="operations",
help="Tests")
self._parser.add_argument("oper", nargs="?", help="Which test to run") self._parser.add_argument("oper", nargs="?", help="Which test to run")
self.interactions = None self.interactions = None
self.entity_id = None self.entity_id = None
self.sp_config = None self.sp_config = None
self.constraints = {} self.constraints = {}
self.operations = None
def json_config_file(self): def json_config_file(self):
if self.args.json_config_file == "-": if self.args.json_config_file == "-":
@@ -165,6 +205,13 @@ class SAML2client(object):
HANDLER = "" HANDLER = ""
self.args = self._parser.parse_args() self.args = self._parser.parse_args()
if self.args.operations:
path, name = os.path.split(self.args.operations)
self.operations = get_mod(name, [path])
else:
self.operations = __import__("idp_test.saml2base",
fromlist=["idp_test"])
if self.args.metadata: if self.args.metadata:
return self.make_meta() return self.make_meta()
elif self.args.list: elif self.args.list: