Adding authn context support

This commit is contained in:
Roland Hedberg
2013-04-21 16:40:02 +02:00
parent c0d9144172
commit 40041b642e
8 changed files with 13242 additions and 54 deletions

View File

@@ -0,0 +1,48 @@
__author__ = 'rolandh'
INTERNETPROTOCOLPASSWORD = \
'urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword'
MOBILETWOFACTORCONTRACT = \
'urn:oasis:names:tc:SAML:2.0:ac:classes:MobileTwoFactorContract'
PASSWORDPROTECTEDTRANSPORT = \
'urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport'
PASSWORD = 'urn:oasis:names:tc:SAML:2.0:ac:classes:Password'
TLSCLIENT = 'urn:oasis:names:tc:SAML:2.0:ac:classes:TLSClient'
class Authn(object):
def __init__(self):
self.db = {}
def add(self, endpoint, spec, target):
"""
Adds a new authentication endpoint.
:param endpoint: The service endpoint URL
:param spec: What the authentication endpoint offers in the form
of an AuthnContext
:param target: The URL of the authentication service
:return:
"""
try:
_endpspec = self.db[endpoint]
except KeyError:
self.db[endpoint] = {}
_endpspec = self.db[endpoint]
if spec.authn_context_class_ref:
_endpspec[spec.authn_context_class_ref.text] = target
elif spec.authn_context_decl:
_endpspec[
spec.authn_context_decl.c_namespace] = spec.authn_context_decl
def pick(self, endpoint, authn_context):
"""
Given which endpoint the request came in over and what
authentication context is defined find out where to send the user next.
:param endpoint: The service endpoint URL
:param authn_context: An AuthnContext instance
:return: An URL
"""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -60,6 +60,7 @@ class ExtensionsType_(SamlBase):
c_child_order = SamlBase.c_child_order[:] c_child_order = SamlBase.c_child_order[:]
c_cardinality = SamlBase.c_cardinality.copy() c_cardinality = SamlBase.c_cardinality.copy()
def extensions_type__from_string(xml_string): def extensions_type__from_string(xml_string):
return saml2.create_class_from_xml_string(ExtensionsType_, xml_string) return saml2.create_class_from_xml_string(ExtensionsType_, xml_string)

View File

@@ -6,16 +6,11 @@ import getopt
import imp import imp
import sys import sys
import types import types
import errno
__version__ = 0.4 __version__ = 0.5
try: from xml.etree import cElementTree as ElementTree
from xml.etree import cElementTree as ElementTree
except ImportError:
try:
import cElementTree as ElementTree
except ImportError:
from elementtree import ElementTree
INDENT = 4*" " INDENT = 4*" "
DEBUG = False DEBUG = False
@@ -47,6 +42,7 @@ def class_pyify(ref):
PROTECTED_KEYWORDS = ["import", "def", "if", "else", "return", "for", PROTECTED_KEYWORDS = ["import", "def", "if", "else", "return", "for",
"while", "not", "try", "except", "in"] "while", "not", "try", "except", "in"]
def def_init(imports, attributes): def def_init(imports, attributes):
indent = INDENT+INDENT indent = INDENT+INDENT
indent3 = INDENT+INDENT+INDENT indent3 = INDENT+INDENT+INDENT
@@ -77,6 +73,7 @@ def def_init(imports, attributes):
line.append("%s):" % indent) line.append("%s):" % indent)
return line return line
def base_init(imports): def base_init(imports):
line = [] line = []
indent4 = INDENT+INDENT+INDENT+INDENT indent4 = INDENT+INDENT+INDENT+INDENT
@@ -104,6 +101,7 @@ def base_init(imports):
line.append("%s)" % indent4) line.append("%s)" % indent4)
return line return line
def initialize(attributes): def initialize(attributes):
indent = INDENT+INDENT indent = INDENT+INDENT
line = [] line = []
@@ -121,6 +119,7 @@ def initialize(attributes):
line.append("%sself.%s=%s" % (indent, _name, _vname)) line.append("%sself.%s=%s" % (indent, _name, _vname))
return line return line
def _mod_typ(prop): def _mod_typ(prop):
try: try:
(mod, typ) = prop.type (mod, typ) = prop.type
@@ -139,6 +138,7 @@ def _mod_typ(prop):
return mod, typ return mod, typ
def _mod_cname(prop, cdict): def _mod_cname(prop, cdict):
if hasattr(prop, "scoped"): if hasattr(prop, "scoped"):
cname = prop.class_name cname = prop.class_name
@@ -155,6 +155,7 @@ def _mod_cname(prop, cdict):
return mod, cname return mod, cname
def leading_uppercase(string): def leading_uppercase(string):
try: try:
return string[0].upper()+string[1:] return string[0].upper()+string[1:]
@@ -163,6 +164,7 @@ def leading_uppercase(string):
except TypeError: except TypeError:
return "" return ""
def leading_lowercase(string): def leading_lowercase(string):
try: try:
return string[0].lower()+string[1:] return string[0].lower()+string[1:]
@@ -171,6 +173,7 @@ def leading_lowercase(string):
except TypeError: except TypeError:
return "" return ""
def rm_duplicates(properties): def rm_duplicates(properties):
keys = [] keys = []
clist = [] clist = []
@@ -189,12 +192,14 @@ def rm_duplicates(properties):
# res.append(item) # res.append(item)
# return res # return res
def klass_namn(obj): def klass_namn(obj):
if obj.class_name: if obj.class_name:
return obj.class_name return obj.class_name
else: else:
return obj.name return obj.name
class PyObj(object): class PyObj(object):
def __init__(self, name=None, pyname=None, root=None): def __init__(self, name=None, pyname=None, root=None):
self.name = name self.name = name
@@ -622,7 +627,7 @@ class PyType(PyObj):
self.namespace = namespace self.namespace = namespace
def text(self, target_namespace, cdict, _child=True, ignore=None, def text(self, target_namespace, cdict, _child=True, ignore=None,
_session=None): _session=None):
if not self.properties and not self.type \ if not self.properties and not self.type \
and not self.superior: and not self.superior:
self.done = True self.done = True
@@ -715,6 +720,8 @@ class PyAttribute(PyObj):
self.namespace = namespace self.namespace = namespace
self.base = None self.base = None
self.type = typ self.type = typ
self.fixed = False
self.default = None
def text(self, _target_namespace, cdict, _child=True): def text(self, _target_namespace, cdict, _child=True):
if isinstance(self.type, PyObj): if isinstance(self.type, PyObj):
@@ -752,6 +759,7 @@ class PyGroup(object):
self.root = root self.root = root
self.properties = [] self.properties = []
self.done = False self.done = False
self.ref = []
def text(self, _target_namespace, _dict, _child, _ignore): def text(self, _target_namespace, _dict, _child, _ignore):
return [], [] return [], []
@@ -837,6 +845,9 @@ def _namespace_and_tag(obj, param, top):
except ValueError: except ValueError:
namespace = "" namespace = ""
tag = param tag = param
# except AttributeError:
# namespace = ""
# tag = obj.name
return namespace, tag return namespace, tag
@@ -852,6 +863,7 @@ class Simple(object):
self.use = None self.use = None
self.ref = None self.ref = None
self.scoped = False self.scoped = False
self.itemType = None
for attribute, value in elem.attrib.iteritems(): for attribute, value in elem.attrib.iteritems():
self.__setattr__(attribute, value) self.__setattr__(attribute, value)
@@ -995,7 +1007,11 @@ class MaxExclusive(Simple):
class List(Simple): class List(Simple):
pass pass
class Include(Simple):
pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def sequence(elem): def sequence(elem):
@@ -1218,7 +1234,7 @@ class Element(Complex):
objekt.type = typ objekt.type = typ
objekt.value_type = {"base": typ} objekt.value_type = {"base": typ}
except AttributeError, exc: except AttributeError:
# neither type nor reference, definitely local # neither type nor reference, definitely local
if hasattr(self, "parts"): if hasattr(self, "parts"):
if len(self.parts) == 1: if len(self.parts) == 1:
@@ -1252,6 +1268,7 @@ class Element(Complex):
return objekt return objekt
class SimpleType(Complex): class SimpleType(Complex):
def repr(self, top=None, _sup=None, _argv=None, _child=True, parent=""): def repr(self, top=None, _sup=None, _argv=None, _child=True, parent=""):
if self.py_class: if self.py_class:
@@ -1284,6 +1301,7 @@ class SimpleType(Complex):
self.py_class = obj self.py_class = obj
return obj return obj
class Sequence(Complex): class Sequence(Complex):
def collect(self, top, sup, argv=None, parent=""): def collect(self, top, sup, argv=None, parent=""):
argv_copy = sd_copy(argv) argv_copy = sd_copy(argv)
@@ -1295,15 +1313,23 @@ class Sequence(Complex):
print "#Sequence: %s" % argv print "#Sequence: %s" % argv
return Complex.collect(self, top, sup, argv_copy, parent) return Complex.collect(self, top, sup, argv_copy, parent)
class SimpleContent(Complex): class SimpleContent(Complex):
pass pass
class ComplexContent(Complex): class ComplexContent(Complex):
pass pass
class Key(Complex): class Key(Complex):
pass pass
class Redefine(Complex):
pass
class Extension(Complex): class Extension(Complex):
def collect(self, top, sup, argv=None, parent=""): def collect(self, top, sup, argv=None, parent=""):
if self._own or self._inherited: if self._own or self._inherited:
@@ -1330,8 +1356,8 @@ class Extension(Complex):
#print "#EXT..-", ia #print "#EXT..-", ia
self._inherited = iattr self._inherited = iattr
except (AttributeError, ValueError): except (AttributeError, ValueError):
pass base = None
self._extend(top, sup, argv, parent, base) self._extend(top, sup, argv, parent, base)
return self._own, self._inherited return self._own, self._inherited
@@ -1525,6 +1551,7 @@ def pyify_0(name):
res += "_" res += "_"
return res return res
def pyify(name): def pyify(name):
# AssertionIDRef # AssertionIDRef
res = [] res = []
@@ -1532,7 +1559,7 @@ def pyify(name):
upc = [] upc = []
pre = "" pre = ""
for char in name: for char in name:
if char >= "A" and char <= "Z": if "A" <= char <= "Z":
upc.append(char) upc.append(char)
elif char == "-": elif char == "-":
upc.append("_") upc.append("_")
@@ -1558,6 +1585,7 @@ def pyify(name):
return "".join(res) return "".join(res)
def get_type_def( typ, defs): def get_type_def( typ, defs):
for cdef in defs: for cdef in defs:
try: try:
@@ -1599,6 +1627,7 @@ def sort_elements(els):
return res, els return res, els
def output(elem, target_namespace, eldict, ignore=None): def output(elem, target_namespace, eldict, ignore=None):
done = 0 done = 0
@@ -1631,6 +1660,7 @@ def output(elem, target_namespace, eldict, ignore=None):
return done return done
def intro(): def intro():
print """#!/usr/bin/env python print """#!/usr/bin/env python
@@ -1644,6 +1674,7 @@ from saml2 import SamlBase
#NAMESPACE = 'http://www.w3.org/2000/09/xmldsig#' #NAMESPACE = 'http://www.w3.org/2000/09/xmldsig#'
def block_items(objekt, block, eldict): def block_items(objekt, block, eldict):
if objekt not in block: if objekt not in block:
if isinstance(objekt.type, PyType): if isinstance(objekt.type, PyType):
@@ -1666,6 +1697,8 @@ def find_parent(elm, eldict):
return find_parent(sup, eldict) return find_parent(sup, eldict)
elif elm.ref: elif elm.ref:
sup = eldict[elm.ref] sup = eldict[elm.ref]
if sup.name == elm.name:
return elm
return find_parent(sup, eldict) return find_parent(sup, eldict)
else: else:
if elm.superior: if elm.superior:
@@ -1676,6 +1709,7 @@ def find_parent(elm, eldict):
return elm return elm
class Schema(Complex): class Schema(Complex):
def __init__(self, elem, impo, add, modul, defs): def __init__(self, elem, impo, add, modul, defs):
@@ -1942,6 +1976,8 @@ _MAP = {
"selector": Selector, "selector": Selector,
"field": Field, "field": Field,
"key": Key, "key": Key,
"include": Include,
"redefine": Redefine
} }
ELEMENTFUNCTION = {} ELEMENTFUNCTION = {}
@@ -1961,7 +1997,6 @@ def evaluate(typ, elem):
NS_MAP = "xmlns_map" NS_MAP = "xmlns_map"
def parse_nsmap(fil): def parse_nsmap(fil):
events = "start", "start-ns", "end-ns" events = "start", "start-ns", "end-ns"
root = None root = None
@@ -2015,46 +2050,48 @@ def get_mod(name, path=None):
raise raise
sys.modules[name] = mod_a sys.modules[name] = mod_a
return mod_a return mod_a
def main(argv):
try:
opts, args = getopt.getopt(argv, "a:d:hi:I:",
["add=", "help", "import=", "defs="])
except getopt.GetoptError, err:
# print help information and exit:
print str(err) # will print something like "option -a not recognized"
usage()
sys.exit(2)
add = []
defs = []
impo = {}
modul = {}
ignore = []
for opt, arg in opts: def recursive_add_xmlns_map(_sch, base):
if opt in ("-a", "--add"): for _part in _sch.parts:
add.append(arg) _part.xmlns_map.update(base.xmlns_map)
elif opt in ("-d", "--defs"): if isinstance(_part, Complex):
defs.append(arg) recursive_add_xmlns_map(_part, base)
elif opt in ("-h", "--help"):
usage()
sys.exit()
elif opt in ("-i", "--import"):
mod = get_mod(arg, ['.'])
modul[mod.NAMESPACE] = mod
impo[mod.NAMESPACE] = [arg, None]
elif opt in ("-I", "--ignore"):
ignore.append(arg)
else:
assert False, "unhandled option"
if not args: def find_and_replace(base, mods):
print "No XSD-file specified" base.xmlns_map = mods.xmlns_map
usage() recursive_add_xmlns_map(base, mods)
sys.exit(2) rm = []
for part in mods.parts:
tree = parse_nsmap(args[0]) try:
_name = part.name
except AttributeError:
continue
for _part in base.parts:
try:
if _name == _part.name:
rm.append(_part)
except AttributeError:
continue
for part in rm:
base.parts.remove(part)
base.parts.extend(mods.parts)
return base
def read_schema(doc, add, defs, impo, modul, ignore, sdir):
for path in sdir:
fil = "%s%s" % (path, doc)
try:
fp = open(fil)
fp.close()
break
except IOError as e:
if e.errno == errno.EACCES:
continue
else:
raise Exception("Could not find schema file")
tree = parse_nsmap(fil)
known = NAMESPACE_BASE[:] known = NAMESPACE_BASE[:]
known.append(XML_NAMESPACE) known.append(XML_NAMESPACE)
@@ -2072,9 +2109,79 @@ def main(argv):
continue continue
else: else:
raise Exception("Undefined namespace: %s" % namespace) raise Exception("Undefined namespace: %s" % namespace)
schema = Schema(tree._root, impo, add, modul, defs)
_schema = Schema(tree._root, impo, add, modul, defs)
_included_parts = []
_remove_parts = []
_replace = []
for part in _schema.parts:
if isinstance(part, Include):
_sch = read_schema(part.schemaLocation, add, defs, impo, modul,
ignore, sdir)
# Add namespace information
recursive_add_xmlns_map(_sch, _schema)
_included_parts.extend(_sch.parts)
_remove_parts.append(part)
elif isinstance(part, Redefine):
# This is the schema that is going to be redefined
_redef = read_schema(part.schemaLocation, add, defs, impo, modul,
ignore, sdir)
# so find and replace
# Use the schema to be redefined as starting point
_replacement = find_and_replace(_redef, part)
_replace.append((part, _replacement.parts))
for part in _remove_parts:
_schema.parts.remove(part)
_schema.parts.extend(_included_parts)
if _replace:
for vad, med in _replace:
_schema.parts.remove(vad)
_schema.parts.extend(med)
return _schema
def main(argv):
try:
opts, args = getopt.getopt(argv, "a:d:hi:I:s:",
["add=", "help", "import=", "defs="])
except getopt.GetoptError, err:
# print help information and exit:
print str(err) # will print something like "option -a not recognized"
usage()
sys.exit(2)
add = []
defs = []
impo = {}
modul = {}
ignore = []
sdir = ["./"]
for opt, arg in opts:
if opt in ("-a", "--add"):
add.append(arg)
elif opt in ("-d", "--defs"):
defs.append(arg)
elif opt in ("-h", "--help"):
usage()
sys.exit()
elif opt in ("-s", "--schemadir"):
sdir.append(arg)
elif opt in ("-i", "--import"):
mod = get_mod(arg, ['.'])
modul[mod.NAMESPACE] = mod
impo[mod.NAMESPACE] = [arg, None]
elif opt in ("-I", "--ignore"):
ignore.append(arg)
else:
assert False, "unhandled option"
if not args:
print "No XSD-file specified"
usage()
sys.exit(2)
schema = read_schema(args[0], add, defs, impo, modul, ignore, sdir)
#print schema.__dict__ #print schema.__dict__
schema.out() schema.out()