Added support for comparision types.

This commit is contained in:
Roland Hedberg
2013-04-28 09:52:57 +02:00
parent 4d138a9b38
commit 3bc4cd1d3d

View File

@@ -1,7 +1,13 @@
from saml2.saml import AuthnContext, AuthnContextClassRef
__author__ = 'rolandh' __author__ = 'rolandh'
import hashlib
from saml2 import extension_elements_to_elements from saml2 import extension_elements_to_elements
UNSPECIFIED = "urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified"
INTERNETPROTOCOLPASSWORD = \ INTERNETPROTOCOLPASSWORD = \
'urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword' 'urn:oasis:names:tc:SAML:2.0:ac:classes:InternetProtocolPassword'
MOBILETWOFACTORCONTRACT = \ MOBILETWOFACTORCONTRACT = \
@@ -22,10 +28,24 @@ from saml2.authn_context import ppt
from saml2.authn_context import pword from saml2.authn_context import pword
from saml2.authn_context import sslcert from saml2.authn_context import sslcert
CMP_TYPE = ['exact', 'minimum', 'maximum', 'better']
class AuthnBroker(object): class AuthnBroker(object):
def __init__(self): def __init__(self):
self.db = {} self.db = {"info": {}, "key": {}}
def exact(self, a, b):
return a == b
def minimum(self, a, b):
return b >= a
def maximum(self, a, b):
return b <= a
def better(self, a, b):
return b > a
def add(self, spec, method, level=0, authn_authority=""): def add(self, spec, method, level=0, authn_authority=""):
""" """
@@ -41,8 +61,9 @@ class AuthnBroker(object):
""" """
if spec.authn_context_class_ref: if spec.authn_context_class_ref:
_ref = spec.authn_context_class_ref.text key = spec.authn_context_class_ref.text
self.db[_ref] = { _info = {
"class_ref": key,
"method": method, "method": method,
"level": level, "level": level,
"authn_auth": authn_authority "authn_auth": authn_authority
@@ -55,12 +76,76 @@ class AuthnBroker(object):
"level": level, "level": level,
"authn_auth": authn_authority "authn_auth": authn_authority
} }
try: else:
self.db[key].append(_info) raise NotImplementedError()
except KeyError:
self.db[key] = [_info]
def pick(self, req_authn_context): m = hashlib.md5()
for attr in ["method", "level", "authn_auth"]:
m.update(str(_info[attr]))
try:
_txt = "%s" % _info["decl"]
except KeyError:
pass
else:
m.update(_txt)
_ref = m.hexdigest()
self.db["info"][_ref] = _info
try:
self.db["key"][key].append(_ref)
except KeyError:
self.db["key"][key] = [_ref]
def remove(self, spec, method=None, level=0, authn_authority=""):
if spec.authn_context_class_ref:
_cls_ref = spec.authn_context_class_ref.text
try:
_refs = self.db["key"][_cls_ref]
except KeyError:
return
else:
_remain = []
for _ref in _refs:
item = self.db["info"][_ref]
if method and method != item["method"]:
_remain.append(_ref)
if level and level != item["level"]:
_remain.append(_ref)
if authn_authority and \
authn_authority != item["authn_authority"]:
_remain.append(_ref)
if _remain:
self.db[_cls_ref] = _remain
def _pick_by_class_ref(self, cls_ref, comparision_type="exact"):
func = getattr(self, comparision_type)
try:
_refs = self.db["key"][cls_ref]
except KeyError:
return []
else:
_item = self.db["info"][_refs[0]]
_level = _item["level"]
if _item["method"]:
res = [(_item["method"], _refs[0])]
else:
res = []
for ref in _refs[1:]:
item = self.db[ref]
res.append((item["method"], ref))
if func(_level, item["level"]):
_level = item["level"]
for ref, _dic in self.db["info"].items():
if ref in _refs:
continue
elif func(_level, _dic["level"]):
if _dic["method"]:
_val = (_dic["method"], ref)
if _val not in res:
res.append(_val)
return res
def pick(self, req_authn_context=None):
""" """
Given the authentication context find zero or more places where Given the authentication context find zero or more places where
the user could be sent next. Ordered according to security level. the user could be sent next. Ordered according to security level.
@@ -70,29 +155,25 @@ class AuthnBroker(object):
:return: An URL :return: An URL
""" """
if req_authn_context is None:
return self._pick_by_class_ref(UNSPECIFIED, "minimum")
if req_authn_context.authn_context_class_ref: if req_authn_context.authn_context_class_ref:
_ref = req_authn_context.authn_context_class_ref.text if req_authn_context.comparison:
try: _cmp = req_authn_context.comparison
_info = self.db[_ref]
except KeyError:
return []
else: else:
_level = _info["level"] _cmp = "minimum"
res = [] return self._pick_by_class_ref(
for key, _dic in self.db.items(): req_authn_context.authn_context_class_ref.text, _cmp)
if key == _ref:
continue
elif _dic["level"] >= _level:
res.append(_dic["method"])
res.insert(0, _info["method"])
return res
elif req_authn_context.authn_context_decl: elif req_authn_context.authn_context_decl:
key = req_authn_context.authn_context_decl.c_namespace _decl = req_authn_context.authn_context_decl
key = _decl.c_namespace
_methods = [] _methods = []
for _dic in self.db[key]: for _ref in self.db["key"][key]:
if self.match(req_authn_context.authn_context_decl, _dic = self.db["info"][_ref]
_dic["decl"]): if self.match(_decl, _dic["decl"]):
_methods.append(_dic["method"]) _val = (_dic["method"], _ref)
if _val not in _methods:
_methods.append(_val)
return _methods return _methods
def match(self, requested, provided): def match(self, requested, provided):
@@ -101,6 +182,9 @@ class AuthnBroker(object):
else: else:
return False return False
def __getitem__(self, ref):
return self.db["info"][ref]
def authn_context_factory(text): def authn_context_factory(text):
# brute force # brute force
@@ -118,4 +202,8 @@ def authn_context_decl_from_extension_elements(extelems):
try: try:
return res[0] return res[0]
except IndexError: except IndexError:
return None return None
def authn_context_class_ref(ref):
return AuthnContext(authn_context_class_ref=AuthnContextClassRef(text=ref))