Adding authn context support

This commit is contained in:
Roland Hedberg
2013-04-21 16:40:02 +02:00
parent c0d9144172
commit 40041b642e
8 changed files with 13242 additions and 54 deletions

View File

@@ -0,0 +1,48 @@
__author__ = 'rolandh'
INTERNETPROTOCOLPASSWORD = \
'urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword'
MOBILETWOFACTORCONTRACT = \
'urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract'
PASSWORDPROTECTEDTRANSPORT = \
'urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport'
PASSWORD = 'urn:oasis:names:tc:SAML:2.0:ac:classes:Password'
TLSCLIENT = 'urn:oasis:names:tc:SAML:2.0:ac:classes:TLSClient'
class Authn(object):
def __init__(self):
self.db = {}
def add(self, endpoint, spec, target):
"""
Adds a new authentication endpoint.
:param endpoint: The service endpoint URL
:param spec: What the authentication endpoint offers in the form
of an AuthnContext
:param target: The URL of the authentication service
:return:
"""
try:
_endpspec = self.db[endpoint]
except KeyError:
self.db[endpoint] = {}
_endpspec = self.db[endpoint]
if spec.authn_context_class_ref:
_endpspec[spec.authn_context_class_ref.text] = target
elif spec.authn_context_decl:
_endpspec[
spec.authn_context_decl.c_namespace] = spec.authn_context_decl
def pick(self, endpoint, authn_context):
"""
Given which endpoint the request came in over and what
authentication context is defined find out where to send the user next.
:param endpoint: The service endpoint URL
:param authn_context: An AuthnContext instance
:return: An URL
"""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -60,6 +60,7 @@ class ExtensionsType_(SamlBase):
c_child_order = SamlBase.c_child_order[:]
c_cardinality = SamlBase.c_cardinality.copy()
def extensions_type__from_string(xml_string):
return saml2.create_class_from_xml_string(ExtensionsType_, xml_string)

View File

@@ -6,16 +6,11 @@ import getopt
import imp
import sys
import types
import errno
__version__ = 0.4
__version__ = 0.5
try:
from xml.etree import cElementTree as ElementTree
except ImportError:
try:
import cElementTree as ElementTree
except ImportError:
from elementtree import ElementTree
from xml.etree import cElementTree as ElementTree
INDENT = 4*" "
DEBUG = False
@@ -47,6 +42,7 @@ def class_pyify(ref):
PROTECTED_KEYWORDS = ["import", "def", "if", "else", "return", "for",
"while", "not", "try", "except", "in"]
def def_init(imports, attributes):
indent = INDENT+INDENT
indent3 = INDENT+INDENT+INDENT
@@ -77,6 +73,7 @@ def def_init(imports, attributes):
line.append("%s):" % indent)
return line
def base_init(imports):
line = []
indent4 = INDENT+INDENT+INDENT+INDENT
@@ -104,6 +101,7 @@ def base_init(imports):
line.append("%s)" % indent4)
return line
def initialize(attributes):
indent = INDENT+INDENT
line = []
@@ -121,6 +119,7 @@ def initialize(attributes):
line.append("%sself.%s=%s" % (indent, _name, _vname))
return line
def _mod_typ(prop):
try:
(mod, typ) = prop.type
@@ -139,6 +138,7 @@ def _mod_typ(prop):
return mod, typ
def _mod_cname(prop, cdict):
if hasattr(prop, "scoped"):
cname = prop.class_name
@@ -155,6 +155,7 @@ def _mod_cname(prop, cdict):
return mod, cname
def leading_uppercase(string):
try:
return string[0].upper()+string[1:]
@@ -163,6 +164,7 @@ def leading_uppercase(string):
except TypeError:
return ""
def leading_lowercase(string):
try:
return string[0].lower()+string[1:]
@@ -171,6 +173,7 @@ def leading_lowercase(string):
except TypeError:
return ""
def rm_duplicates(properties):
keys = []
clist = []
@@ -189,12 +192,14 @@ def rm_duplicates(properties):
# res.append(item)
# return res
def klass_namn(obj):
if obj.class_name:
return obj.class_name
else:
return obj.name
class PyObj(object):
def __init__(self, name=None, pyname=None, root=None):
self.name = name
@@ -622,7 +627,7 @@ class PyType(PyObj):
self.namespace = namespace
def text(self, target_namespace, cdict, _child=True, ignore=None,
_session=None):
_session=None):
if not self.properties and not self.type \
and not self.superior:
self.done = True
@@ -715,6 +720,8 @@ class PyAttribute(PyObj):
self.namespace = namespace
self.base = None
self.type = typ
self.fixed = False
self.default = None
def text(self, _target_namespace, cdict, _child=True):
if isinstance(self.type, PyObj):
@@ -752,6 +759,7 @@ class PyGroup(object):
self.root = root
self.properties = []
self.done = False
self.ref = []
def text(self, _target_namespace, _dict, _child, _ignore):
return [], []
@@ -837,6 +845,9 @@ def _namespace_and_tag(obj, param, top):
except ValueError:
namespace = ""
tag = param
# except AttributeError:
# namespace = ""
# tag = obj.name
return namespace, tag
@@ -852,6 +863,7 @@ class Simple(object):
self.use = None
self.ref = None
self.scoped = False
self.itemType = None
for attribute, value in elem.attrib.iteritems():
self.__setattr__(attribute, value)
@@ -995,7 +1007,11 @@ class MaxExclusive(Simple):
class List(Simple):
pass
class Include(Simple):
pass
# -----------------------------------------------------------------------------
def sequence(elem):
@@ -1218,7 +1234,7 @@ class Element(Complex):
objekt.type = typ
objekt.value_type = {"base": typ}
except AttributeError, exc:
except AttributeError:
# neither type nor reference, definitely local
if hasattr(self, "parts"):
if len(self.parts) == 1:
@@ -1252,6 +1268,7 @@ class Element(Complex):
return objekt
class SimpleType(Complex):
def repr(self, top=None, _sup=None, _argv=None, _child=True, parent=""):
if self.py_class:
@@ -1284,6 +1301,7 @@ class SimpleType(Complex):
self.py_class = obj
return obj
class Sequence(Complex):
def collect(self, top, sup, argv=None, parent=""):
argv_copy = sd_copy(argv)
@@ -1295,15 +1313,23 @@ class Sequence(Complex):
print "#Sequence: %s" % argv
return Complex.collect(self, top, sup, argv_copy, parent)
class SimpleContent(Complex):
pass
class ComplexContent(Complex):
pass
class Key(Complex):
pass
class Redefine(Complex):
pass
class Extension(Complex):
def collect(self, top, sup, argv=None, parent=""):
if self._own or self._inherited:
@@ -1330,8 +1356,8 @@ class Extension(Complex):
#print "#EXT..-", ia
self._inherited = iattr
except (AttributeError, ValueError):
pass
base = None
self._extend(top, sup, argv, parent, base)
return self._own, self._inherited
@@ -1525,6 +1551,7 @@ def pyify_0(name):
res += "_"
return res
def pyify(name):
# AssertionIDRef
res = []
@@ -1532,7 +1559,7 @@ def pyify(name):
upc = []
pre = ""
for char in name:
if char >= "A" and char <= "Z":
if "A" <= char <= "Z":
upc.append(char)
elif char == "-":
upc.append("_")
@@ -1558,6 +1585,7 @@ def pyify(name):
return "".join(res)
def get_type_def( typ, defs):
for cdef in defs:
try:
@@ -1599,6 +1627,7 @@ def sort_elements(els):
return res, els
def output(elem, target_namespace, eldict, ignore=None):
done = 0
@@ -1631,6 +1660,7 @@ def output(elem, target_namespace, eldict, ignore=None):
return done
def intro():
print """#!/usr/bin/env python
@@ -1644,6 +1674,7 @@ from saml2 import SamlBase
#NAMESPACE = 'http://www.w3.org/2000/09/xmldsig#'
def block_items(objekt, block, eldict):
if objekt not in block:
if isinstance(objekt.type, PyType):
@@ -1666,6 +1697,8 @@ def find_parent(elm, eldict):
return find_parent(sup, eldict)
elif elm.ref:
sup = eldict[elm.ref]
if sup.name == elm.name:
return elm
return find_parent(sup, eldict)
else:
if elm.superior:
@@ -1676,6 +1709,7 @@ def find_parent(elm, eldict):
return elm
class Schema(Complex):
def __init__(self, elem, impo, add, modul, defs):
@@ -1942,6 +1976,8 @@ _MAP = {
"selector": Selector,
"field": Field,
"key": Key,
"include": Include,
"redefine": Redefine
}
ELEMENTFUNCTION = {}
@@ -1961,7 +1997,6 @@ def evaluate(typ, elem):
NS_MAP = "xmlns_map"
def parse_nsmap(fil):
events = "start", "start-ns", "end-ns"
root = None
@@ -2015,46 +2050,48 @@ def get_mod(name, path=None):
raise
sys.modules[name] = mod_a
return mod_a
def main(argv):
try:
opts, args = getopt.getopt(argv, "a:d:hi:I:",
["add=", "help", "import=", "defs="])
except getopt.GetoptError, err:
# print help information and exit:
print str(err) # will print something like "option -a not recognized"
usage()
sys.exit(2)
add = []
defs = []
impo = {}
modul = {}
ignore = []
for opt, arg in opts:
if opt in ("-a", "--add"):
add.append(arg)
elif opt in ("-d", "--defs"):
defs.append(arg)
elif opt in ("-h", "--help"):
usage()
sys.exit()
elif opt in ("-i", "--import"):
mod = get_mod(arg, ['.'])
modul[mod.NAMESPACE] = mod
impo[mod.NAMESPACE] = [arg, None]
elif opt in ("-I", "--ignore"):
ignore.append(arg)
else:
assert False, "unhandled option"
def recursive_add_xmlns_map(_sch, base):
for _part in _sch.parts:
_part.xmlns_map.update(base.xmlns_map)
if isinstance(_part, Complex):
recursive_add_xmlns_map(_part, base)
if not args:
print "No XSD-file specified"
usage()
sys.exit(2)
tree = parse_nsmap(args[0])
def find_and_replace(base, mods):
base.xmlns_map = mods.xmlns_map
recursive_add_xmlns_map(base, mods)
rm = []
for part in mods.parts:
try:
_name = part.name
except AttributeError:
continue
for _part in base.parts:
try:
if _name == _part.name:
rm.append(_part)
except AttributeError:
continue
for part in rm:
base.parts.remove(part)
base.parts.extend(mods.parts)
return base
def read_schema(doc, add, defs, impo, modul, ignore, sdir):
for path in sdir:
fil = "%s%s" % (path, doc)
try:
fp = open(fil)
fp.close()
break
except IOError as e:
if e.errno == errno.EACCES:
continue
else:
raise Exception("Could not find schema file")
tree = parse_nsmap(fil)
known = NAMESPACE_BASE[:]
known.append(XML_NAMESPACE)
@@ -2072,9 +2109,79 @@ def main(argv):
continue
else:
raise Exception("Undefined namespace: %s" % namespace)
schema = Schema(tree._root, impo, add, modul, defs)
_schema = Schema(tree._root, impo, add, modul, defs)
_included_parts = []
_remove_parts = []
_replace = []
for part in _schema.parts:
if isinstance(part, Include):
_sch = read_schema(part.schemaLocation, add, defs, impo, modul,
ignore, sdir)
# Add namespace information
recursive_add_xmlns_map(_sch, _schema)
_included_parts.extend(_sch.parts)
_remove_parts.append(part)
elif isinstance(part, Redefine):
# This is the schema that is going to be redefined
_redef = read_schema(part.schemaLocation, add, defs, impo, modul,
ignore, sdir)
# so find and replace
# Use the schema to be redefined as starting point
_replacement = find_and_replace(_redef, part)
_replace.append((part, _replacement.parts))
for part in _remove_parts:
_schema.parts.remove(part)
_schema.parts.extend(_included_parts)
if _replace:
for vad, med in _replace:
_schema.parts.remove(vad)
_schema.parts.extend(med)
return _schema
def main(argv):
try:
opts, args = getopt.getopt(argv, "a:d:hi:I:s:",
["add=", "help", "import=", "defs="])
except getopt.GetoptError, err:
# print help information and exit:
print str(err) # will print something like "option -a not recognized"
usage()
sys.exit(2)
add = []
defs = []
impo = {}
modul = {}
ignore = []
sdir = ["./"]
for opt, arg in opts:
if opt in ("-a", "--add"):
add.append(arg)
elif opt in ("-d", "--defs"):
defs.append(arg)
elif opt in ("-h", "--help"):
usage()
sys.exit()
elif opt in ("-s", "--schemadir"):
sdir.append(arg)
elif opt in ("-i", "--import"):
mod = get_mod(arg, ['.'])
modul[mod.NAMESPACE] = mod
impo[mod.NAMESPACE] = [arg, None]
elif opt in ("-I", "--ignore"):
ignore.append(arg)
else:
assert False, "unhandled option"
if not args:
print "No XSD-file specified"
usage()
sys.exit(2)
schema = read_schema(args[0], add, defs, impo, modul, ignore, sdir)
#print schema.__dict__
schema.out()