This commit is contained in:
Roland Hedberg
2010-04-17 20:12:43 +02:00
parent 13b8dbf775
commit f28754df35
10 changed files with 528 additions and 334 deletions

View File

@@ -376,7 +376,7 @@ def make_vals(val, klass, klass_inst=None, prop=None, part=False,
else:
try:
cinst = klass().set_text(val)
except ValueError, excp:
except ValueError:
if not part:
cis = [make_vals(sval, klass, klass_inst, prop, True,
base64encode) for sval in val]
@@ -482,11 +482,11 @@ class SamlBase(ExtensionContainer):
def become_child_element_of(self, tree):
"""
Note: Only for use with classes that have a c_tag and c_namespace class
member. It is in SamlBase so that it can be inherited but it should
not be called on instances of SamlBase.
:param tree: The tree to which this instance should be a child
"""
new_child = self._to_element_tree()
tree.append(new_child)
@@ -519,15 +519,28 @@ class SamlBase(ExtensionContainer):
self.__dict__[extension_attribute_name] = value
def keyswv(self):
""" Return the keys of attributes or children that has values
:return: list of keys
"""
return [key for key, val in self.__dict__.items() if val]
def keys(self):
""" Return all the keys that represent possible attributes and
children.
:return: list of keys
"""
keys = ['text']
keys.extend(self.c_attributes.values())
keys.extend([v[1] for v in self.c_children.values()])
return keys
def children_with_values(self):
""" Returns all children that has values
:return: Possibly empty list of children.
"""
childs = []
for _, values in self.__class__.c_children.iteritems():
member = getattr(self, values[0])
@@ -541,7 +554,13 @@ class SamlBase(ExtensionContainer):
return childs
def set_text(self, val, base64encode=False):
""" """
""" Sets the text property of this instance.
:param val: The value of the text property
:param base64encode: Whether the value should be base64encoded
:return: The instance
"""
#print "set_text: %s" % (val,)
if isinstance(val, bool):
if val:
@@ -560,7 +579,19 @@ class SamlBase(ExtensionContainer):
return self
def loadd(self, ava, base64encode=False):
""" """
"""
Sets attributes, children, extension elements and extension
attributes of this element instance depending on what is in
the given dictionary. If there are already values on properties
those will be overwritten. If the keys in the dictionary does
not correspond to known attributes/children/.. they are ignored.
:param ava: The dictionary
:param base64encode: Whether the values on attributes or texts on
children shoule be base64encoded.
:return: The instance
"""
for prop in self.c_attributes.values():
#print "# %s" % (prop)
if prop in ava:
@@ -578,7 +609,8 @@ class SamlBase(ExtensionContainer):
#print "## %s, %s" % (prop, klassdef)
if prop in ava:
#print "### %s" % ava[prop]
if isinstance(klassdef, list): # means there can be a list of values
# means there can be a list of values
if isinstance(klassdef, list):
make_vals(ava[prop], klassdef[0], self, prop,
base64encode=base64encode)
else:
@@ -588,7 +620,8 @@ class SamlBase(ExtensionContainer):
if "extension_elements" in ava:
for item in ava["extension_elements"]:
self.extension_elements.append(ExtensionElement(item["tag"]).loadd(item))
self.extension_elements.append(ExtensionElement(
item["tag"]).loadd(item))
if "extension_attributes" in ava:
for key, val in ava["extension_attributes"].items():
@@ -598,22 +631,39 @@ class SamlBase(ExtensionContainer):
def element_to_extension_element(element):
ee = ExtensionElement(element.c_tag, element.c_namespace,
"""
Convert an element into a extension element
:param element: The element instance
:return: An extension element instance
"""
exel = ExtensionElement(element.c_tag, element.c_namespace,
text=element.text)
for xml_attribute, member_name in element.c_attributes.iteritems():
member_value = getattr(element, member_name)
if member_value is not None:
ee.attributes[xml_attribute] = member_value
exel.attributes[xml_attribute] = member_value
ee.children = [element_to_extension_element(c) \
exel.children = [element_to_extension_element(c) \
for c in element.children_with_values()]
return ee
return exel
def extension_element_to_element(extension_element, translation_functions,
namespace=None):
""" """
""" Convert an extension element to a normal element.
In order to do this you need to have an idea of what type of
element it is. Or rather which module it belongs to.
:param extension_element: The extension element
:prama translation_functions: A dictionary which klass identifiers
as keys and string-to-element translations functions as values
:param namespace: The namespace of the translation functions.
:return: An element instance or None
"""
try:
element_namespace = extension_element.namespace
except AttributeError:

View File

@@ -35,10 +35,10 @@ def _filter_values(vals, required=None, optional=None):
:param optional: The optional values
:return: The set of values after filtering
"""
if not required and not optional:
return vals
valr = []
valo = []
if required:
@@ -54,17 +54,17 @@ def _filter_values(vals, required=None, optional=None):
valr.append(val)
elif val in ovals:
valo.append(val)
valo.extend(valr)
if rvals:
if len(rvals) == len(valr):
return valo
else:
a = set(rvals)
_ = set(rvals)
raise MissingValue("Required attribute value missing")
else:
return valo
def _combine(required=None, optional=None):
res = {}
if not required:
@@ -81,12 +81,12 @@ def _combine(required=None, optional=None):
res[(attr.name, attr.friendly_name)] = part
else:
res[(attr.name, attr.friendly_name)] = (attr.attribute_value, [])
for oat in optional:
tag = (oat.name, oat.friendly_name)
if tag not in res:
res[tag] = ([], oat.attribute_value)
return res
def filter_on_attributes(ava, required=None, optional=None):
@@ -102,24 +102,48 @@ def filter_on_attributes(ava, required=None, optional=None):
res[attr[1]] = _filter_values(ava[attr[1]], vals[0], vals[1])
else:
print >> sys.stderr, ava.keys()
raise MissingValue("Required attribute missing: '%s'" % (attr,))
raise MissingValue("Required attribute missing: '%s'" % (attr[0],))
return res
def filter_on_demands(ava, required={}, optional={}):
""" Never return more than is needed """
# Is all what's required there:
for attr, vals in required.items():
if attr in ava:
if vals:
for val in vals:
if val not in ava[attr]:
raise MissingValue(
"Required attribute value missing: %s,%s" % (attr,
val))
else:
raise MissingValue("Required attribute missing: %s" % (attr,))
# OK, so I can imaging releasing values that are not absolutely necessary
# but not attributes
for attr, vals in ava.items():
if attr not in required and attr not in optional:
del ava[attr]
return ava
def filter_attribute_value_assertions(ava, attribute_restrictions=None):
""" Will weed out attribute values and values according to the
rules defined in the attribute restrictions. If filtering results in
""" Will weed out attribute values and values according to the
rules defined in the attribute restrictions. If filtering results in
an attribute without values, then the attribute is removed from the
assertion.
:param ava: The incoming attribute value assertion
:param ava: The incoming attribute value assertion (dictionary)
:param attribute_restrictions: The rules that govern which attributes
and values that are allowed.
and values that are allowed. (dictionary)
:return: The modified attribute value assertion
"""
if not attribute_restrictions:
return ava
for attr, vals in ava.items():
if attr in attribute_restrictions:
if attribute_restrictions[attr]:
@@ -136,7 +160,7 @@ def filter_attribute_value_assertions(ava, attribute_restrictions=None):
else:
del ava[attr]
return ava
class Policy(object):
""" handles restrictions on assertions """
@@ -145,11 +169,11 @@ class Policy(object):
self.compile(restrictions)
else:
self._restrictions = None
def compile(self, restrictions):
""" This is only for IdPs or AAs, and it's about limiting what
is returned to the SP.
In the configuration file, restrictions on which values that
is returned to the SP.
In the configuration file, restrictions on which values that
can be returned are specified with the help of regular expressions.
This function goes through and pre-compiles the regular expressions.
@@ -163,25 +187,25 @@ class Policy(object):
for _, spec in self._restrictions.items():
if spec == None:
continue
try:
restr = spec["attribute_restrictions"]
except KeyError:
continue
if restr == None:
continue
for key, values in restr.items():
if not values:
spec["attribute_restrictions"][key] = None
continue
spec["attribute_restrictions"][key] = \
[re.compile(value) for value in values]
return self._restrictions
def get_nameid_format(self, sp_entity_id):
try:
form = self._restrictions[sp_entity_id]["nameid_format"]
@@ -190,7 +214,7 @@ class Policy(object):
form = self._restrictions["default"]["nameid_format"]
except KeyError:
form = saml.NAMEID_FORMAT_TRANSIENT
return form
def get_name_form(self, sp_entity_id):
@@ -203,7 +227,7 @@ class Policy(object):
form = self._restrictions["default"]["name_form"]
except KeyError:
pass
return form
def get_lifetime(self, sp_entity_id):
@@ -211,7 +235,7 @@ class Policy(object):
spec = {"hours":1}
if not self._restrictions:
return spec
try:
spec = self._restrictions[sp_entity_id]["lifetime"]
except KeyError:
@@ -219,13 +243,13 @@ class Policy(object):
spec = self._restrictions["default"]["lifetime"]
except KeyError:
pass
return spec
return spec
def get_attribute_restriction(self, sp_entity_id):
if not self._restrictions:
return None
try:
try:
restrictions = self._restrictions[sp_entity_id][
@@ -238,9 +262,9 @@ class Policy(object):
restrictions = None
except KeyError:
restrictions = None
return restrictions
return restrictions
def _not_on_or_after(self, sp_entity_id):
""" When the assertion stops being valid, should not be
used after this time.
@@ -249,13 +273,13 @@ class Policy(object):
"""
return in_a_while(**self.get_lifetime(sp_entity_id))
def filter(self, ava, sp_entity_id, required=None, optional=None):
""" What attribute and attribute values returns depends on what
the SP has said it wants in the request or in the metadata file and
what the IdP/AA wants to release. An assumption is that what the SP
asks for overrides whatever is in the metadata. But of course the
IdP never releases anything it doesn't want to.
asks for overrides whatever is in the metadata. But of course the
IdP never releases anything it doesn't want to.
:param ava: The information about the subject as a dictionary
:param sp_entity_id: The entity ID of the SP
@@ -263,18 +287,18 @@ class Policy(object):
:param optional: Attributes that the SP thinks is optional
:return: A possibly modified AVA
"""
ava = filter_attribute_value_assertions(ava,
self.get_attribute_restriction(sp_entity_id))
ava = filter_attribute_value_assertions(ava,
self.get_attribute_restriction(sp_entity_id))
if required or optional:
ava = filter_on_attributes(ava, required, optional)
return ava
def restrict(self, ava, sp_entity_id, metadata=None):
""" Identity attribute names are expected to be expressed in
""" Identity attribute names are expected to be expressed in
the local lingo (== friendlyName)
:return: A filtered ava according to the IdPs/AAs rules and
@@ -283,16 +307,17 @@ class Policy(object):
"""
if metadata:
(required, optional) = metadata.attribute_consumer(sp_entity_id)
#(required, optional) = metadata.wants(sp_entity_id)
else:
required = optional = None
return self.filter(ava, sp_entity_id, required, optional)
def conditions(self, sp_entity_id):
return args2dict(
not_before=instant(),
not_before=instant(),
# How long might depend on who's getting it
not_on_or_after=self._not_on_or_after(sp_entity_id),
not_on_or_after=self._not_on_or_after(sp_entity_id),
audience_restriction=args2dict(
audience=args2dict(sp_entity_id)))
@@ -301,16 +326,16 @@ class Assertion(dict):
def __init__(self, dic=None):
dict.__init__(self, dic)
def _authn_statement(self):
return args2dict(authn_instant=instant(), session_index=sid())
def construct(self, sp_entity_id, in_response_to, name_id, attrconvs,
policy, issuer):
def construct(self, sp_entity_id, in_response_to, name_id, attrconvs,
policy, issuer):
attr_statement = from_local(attrconvs, self,
policy.get_name_form(sp_entity_id))
# start using now and for a hour
conds = policy.conditions(sp_entity_id)
@@ -326,6 +351,6 @@ class Assertion(dict):
subject_confirmation_data = \
args2dict(in_response_to=in_response_to))),
)
def apply_policy(self, sp_entity_id, policy, metadata=None):
return policy.restrict(self, sp_entity_id, metadata)

View File

@@ -17,6 +17,7 @@
import os
from saml2.utils import args2dict
from saml2.saml import NAME_FORMAT_URI
class UnknownNameFormat(Exception):
pass
@@ -26,44 +27,82 @@ def ac_factory(path):
for tup in os.walk(path):
if tup[2]:
ac = AttributeConverter(os.path.basename(tup[0]))
atco = AttributeConverter(os.path.basename(tup[0]))
for name in tup[2]:
fname = os.path.join(tup[0], name)
if name.endswith(".py"):
name = name[:-3]
ac.set(name,fname)
ac.adjust()
acs.append(ac)
atco.set(name, fname)
atco.adjust()
acs.append(atco)
return acs
def ava_fro(acs, statement):
""" translates attributes according to their name_formats """
if statement == []:
return {}
acsdic = dict([(ac.format, ac) for ac in acs])
acsdic[None] = acsdic[NAME_FORMAT_URI]
return dict([acsdic[a.name_format].ava_fro(a) for a in statement])
def to_local(acs, statement):
if not acs:
acs = [AttributeConverter()]
ava = []
for ac in acs:
for aconv in acs:
try:
ava = ac.fro(statement)
ava = aconv.fro(statement)
break
except UnknownNameFormat:
pass
return ava
def from_local(acs, ava, name_format):
for ac in acs:
for aconv in acs:
#print ac.format, name_format
if ac.format == name_format:
if aconv.format == name_format:
#print "Found a name_form converter"
return ac.to(ava)
return aconv.to(ava)
return None
def from_local_name(acs, attr, name_format):
"""
:param acs: List of AttributeConverter instances
:param attr: attribute name as string
:param name_format: Which name-format it should be translated to
:return: A dictionary suitable to feed to make_instance
"""
for aconv in acs:
#print ac.format, name_format
if aconv.format == name_format:
#print "Found a name_form converter"
return aconv.to_format(attr)
return attr
def to_local_name(acs, attr):
"""
:param acs: List of AttributeConverter instances
:param attr: an Attribute instance
:return: The local attribute name
"""
for aconv in acs:
lattr = aconv.from_format(attr)
if lattr:
return lattr
return attr.friendly_name
class AttributeConverter(object):
""" Converts from an attribute statement to a key,value dictionary and
vice-versa """
def __init__(self, format=""):
self.format = format
def __init__(self, name_format=""):
self.name_format = name_format
self._to = None
self._fro = None
def set(self, name, filename):
if name == "to":
@@ -101,6 +140,24 @@ class AttributeConverter(object):
result[name].append(value.text.strip())
return result
def ava_fro(self, attribute):
try:
attr = self._fro[attribute.name.strip()]
except (AttributeError, KeyError):
try:
attr = attribute.friendly_name.strip()
except AttributeError:
attr = attribute.name.strip()
val = []
for value in attribute.attribute_value:
if not value.text:
val.append('')
else:
val.append(value.text.strip())
return (attr, val)
def fro(self, statement):
""" Get the attributes and the attribute values
@@ -108,41 +165,48 @@ class AttributeConverter(object):
:return: A dictionary containing attributes and values
"""
if not self.format:
if not self.name_format:
return self.fail_safe_fro(statement)
result = {}
for attribute in statement.attribute:
if attribute.name_format and self.format and \
attribute.name_format != self.format:
if attribute.name_format and self.name_format and \
attribute.name_format != self.name_format:
raise UnknownNameFormat
try:
name = self._fro[attribute.name.strip()]
except (AttributeError, KeyError):
try:
name = attribute.friendly_name.strip()
except AttributeError:
name = attribute.name.strip()
result[name] = []
for value in attribute.attribute_value:
if not value.text:
result[name].append('')
else:
result[name].append(value.text.strip())
(key, val) = self.ava_fro(attribute)
result[key] = val
if not result:
return self.fail_safe_fro(statement)
else:
return result
def to_format(self, attr):
try:
return args2dict(name=self._to[attr], name_format=self.name_format,
friendly_name=attr)
except KeyError:
return args2dict(name=attr)
def from_format(self, attr):
"""
:param attr: An saml.Attribute instance
:return: The local attribute name or "" if no mapping could be made
"""
if self.name_format == attr.name_format:
try:
return self._fro[attr.name]
except KeyError:
pass
return ""
def to(self, ava):
attributes = []
for key, value in ava.items():
try:
attributes.append(args2dict(name=self._to[key],
name_format=self.format,
name_format=self.name_format,
friendly_name=key,
attribute_value=value))
except KeyError:

View File

@@ -20,7 +20,10 @@ import time
from saml2.time_util import str_to_time
from saml2 import samlp, saml
from saml2 import samlp
from saml2 import saml
from saml2 import extension_element_to_element
from saml2.sigver import security_context
from saml2.attribute_converter import to_local
@@ -46,7 +49,7 @@ def _use_before(condition, slack):
#print "NOW: %s" % now
not_before = time.mktime(str_to_time(condition.not_before))
#print "not_before: %d" % not_before
if not_before > now + slack:
# Can't use it yet
raise Exception("Can't use it yet %s <= %s" % (not_before, now))
@@ -63,15 +66,15 @@ def for_me(condition, myself ):
pass
return False
# ---------------------------------------------------------------------------
class IncorrectlySigned(Exception):
pass
# ---------------------------------------------------------------------------
def authn_response(conf, requestor, outstanding_queries=None, log=None,
def authn_response(conf, requestor, outstanding_queries=None, log=None,
timeslack=0, debug=0):
sec = security_context(conf)
if not timeslack:
@@ -79,15 +82,15 @@ def authn_response(conf, requestor, outstanding_queries=None, log=None,
timeslack = int(conf["timeslack"])
except KeyError:
pass
return AuthnResponse(sec, conf.attribute_converters, requestor,
outstanding_queries, log, timeslack, debug)
return AuthnResponse(sec, conf.attribute_converters, requestor,
outstanding_queries, log, timeslack, debug)
class AuthnResponse(object):
def __init__(self, security_context, attribute_converters, requestor,
def __init__(self, sec_context, attribute_converters, requestor,
outstanding_queries=None, log=None, timeslack=0, debug=0):
self.sc = security_context
self.sec = sec_context
self.attribute_converters = attribute_converters
self.requestor = requestor
if outstanding_queries:
@@ -100,9 +103,15 @@ class AuthnResponse(object):
self.debug = debug
if self.debug and not self.log:
self.debug = 0
self.clear()
self.xmlstr = ""
self.xmlstr = ""
self.came_from = ""
self.name_id = ""
self.ava = None
self.response = None
self.not_on_or_after = 0
self.assertion = None
def loads(self, xmldata, decode=True):
if self.debug:
self.log.info("--- Loads AuthnResponse ---")
@@ -112,11 +121,11 @@ class AuthnResponse(object):
decoded_xml = xmldata
# own copy
self.xmlstr = decoded_xml[:]
self.xmlstr = decoded_xml[:]
if self.debug:
self.log.info("xmlstr: %s" % (self.xmlstr,))
try:
self.response = self.sc.correctly_signed_response(decoded_xml)
self.response = self.sec.correctly_signed_response(decoded_xml)
except Exception, excp:
self.log and self.log.info("EXCEPTION: %s", excp)
raise
@@ -126,12 +135,12 @@ class AuthnResponse(object):
self.log.error("Response was not correctly signed")
self.log.info(decoded_xml)
raise IncorrectlySigned()
if self.debug:
self.log.info("response: %s" % (self.response,))
return self
return self
def clear(self):
self.xmlstr = ""
self.came_from = ""
@@ -140,7 +149,7 @@ class AuthnResponse(object):
self.response = None
self.not_on_or_after = 0
self.assertion = None
def status_ok(self):
if self.response.status:
status = self.response.status
@@ -152,13 +161,13 @@ class AuthnResponse(object):
raise Exception(
"Not successfull according to: %s" % \
status.status_code.value)
def authn_statement_ok(self):
# the assertion MUST contain one AuthNStatement
assert len(self.assertion.authn_statement) == 1
# authn_statement = assertion.authn_statement[0]
# check authn_statement.session_index
def condition_ok(self, lax=False):
# The Identity Provider MUST include a <saml:Conditions> element
#print "Conditions",assertion.conditions
@@ -166,23 +175,23 @@ class AuthnResponse(object):
condition = self.assertion.conditions
if self.debug:
self.log.info("condition: %s" % condition)
try:
self.not_on_or_after = _use_on_or_after(condition, self.timeslack)
_use_before(condition, self.timeslack)
except Exception,excp:
except Exception, excp:
self.log.error("Exception on condition: %s" % (excp,))
if not lax:
raise
else:
self.not_on_or_after = 0
if not for_me(condition, self.requestor):
if not lax:
raise Exception("Not for me!!!")
return True
def get_identity(self):
# The assertion can contain zero or one attributeStatements
if not self.assertion.attribute_statement:
@@ -194,13 +203,13 @@ class AuthnResponse(object):
if self.debug:
self.log.info("Attribute Statement: %s" % (
self.assertion.attribute_statement[0],))
for ac in self.attribute_converters():
self.log.info("Converts name format: %s" % (ac.format,))
for aconv in self.attribute_converters():
self.log.info("Converts name format: %s" % (aconv.format,))
ava = to_local(self.attribute_converters(),
self.assertion.attribute_statement[0])
return ava
def get_subject(self):
# The assertion must contain a Subject
assert self.assertion.subject
@@ -221,7 +230,7 @@ class AuthnResponse(object):
# The subject must contain a name_id
assert subject.name_id
self.name_id = subject.name_id.text.strip()
def _assertion(self, assertion):
self.assertion = assertion
@@ -233,24 +242,24 @@ class AuthnResponse(object):
if self.context == "AuthNReq":
self.authn_statement_ok()
if not self.condition_ok():
return None
if self.debug:
self.log.info("--- Getting Identity ---")
self.ava = self.get_identity()
if self.debug:
self.log.info("--- AVA: %s" % (self.ava,))
self.get_subject()
return True
def _encrypted_assertion(self, xmlstr):
decrypt_xml = self.sc.decrypt(self.xmlstr)
decrypt_xml = self.sec.decrypt(xmlstr)
if self.debug:
self.log.info("Decryption successfull")
@@ -259,57 +268,66 @@ class AuthnResponse(object):
if self.debug:
self.log.info("Parsed decrypted assertion successfull")
enc = self.response.encrypted_assertion[0].extension_elements[0]
assertion = extension_element_to_element(enc,
enc = self.response.encrypted_assertion[0].extension_elements[0]
assertion = extension_element_to_element(enc,
saml.ELEMENT_FROM_STRING,
namespace=saml.NAMESPACE)
if self.debug:
self.log.info("Decrypted Assertion: %s" % assertion)
return self._assertion(assertion)
def parse_assertion(self):
try:
assert len(self.response.assertion) == 1 or \
len(self.response.encrypted_assertion) == 1
except AssertionError:
raise Exception("No assertion part")
if self.response.assertion:
if self.response.assertion:
self.debug and self.log.info("***Unencrypted response***")
return self._assertion(self.response.assertion[0])
else:
self.debug and self.log.info("***Encrypted response***")
return self._encrypted_assertion(
self.response.encrypted_assertion[0])
return True
return True
def verify(self):
""" """
""" Verify that the assertion is syntaktically correct and
the signature is correct if present."""
self.status_ok()
if self.parse_assertion():
return self
else:
return None
def issuer(self):
""" Return the issuer of the reponse """
return self.response.issuer.text
def session_id(self):
""" Returns the SessionID of the response """
return self.response.in_response_to
def id(self):
""" Return the ID of the response """
return self.response.id
def session_info(self):
return { "ava": self.ava, "name_id": self.name_id,
""" Returns a predefined set of information gleened from the
response.
:returns: Dictionary with information
"""
return { "ava": self.ava, "name_id": self.name_id,
"came_from": self.came_from, "issuer": self.issuer(),
"not_on_or_after": self.not_on_or_after }
def __str__(self):
return "%s" % self.xmlstr
# ======================================================================
# session_info["ava"]["__userid"] = session_info["name_id"]
# return session_info

View File

@@ -23,19 +23,17 @@ import os
import urllib
import saml2
import base64
import time
import sys
from saml2.time_util import str_to_time, instant
from saml2.time_util import instant
from saml2.utils import sid, deflate_and_base64_encode
from saml2.utils import do_attributes, args2dict
from saml2 import samlp, saml, extension_element_to_element
from saml2 import VERSION, class_name, make_instance
from saml2 import samlp, saml
from saml2 import VERSION, make_instance
from saml2.sigver import pre_signature_part
from saml2.sigver import security_context, signed_instance_factory
from saml2.soap import SOAPClient
from saml2.attribute_converter import to_local
from saml2.authnresponse import authn_response
DEFAULT_BINDING = saml2.BINDING_HTTP_REDIRECT
@@ -61,51 +59,52 @@ class Saml2Client(object):
self.config = config
if "metadata" in config:
self.metadata = config["metadata"]
self.sc = security_context(config)
self.sec = security_context(config)
self.debug = debug
def _init_request(self, request, destination):
#request.id = sid()
request.version = VERSION
request.issue_instant = instant()
request.destination = destination
return request
return request
def idp_entry(self, name=None, location=None, provider_id=None):
res = {}
if name:
if name:
res["name"] = name
if location:
if location:
res["loc"] = location
if provider_id:
if provider_id:
res["provider_id"] = provider_id
if res:
return res
else:
return None
def scoping(self, idp_ents):
return {
"idp_list": {
"idp_entry": idp_ents
}
}
def scoping_from_metadata(self, entityid, location):
name = self.metadata.name(entityid)
return make_instance(self.scoping([self.idp_entry(name, location)]))
return make_instance(samlp.Scoping,
self.scoping([self.idp_entry(name, location)]))
def response(self, post, requestor, outstanding, log=None):
""" Deal with the AuthnResponse
:param post: The reply as a dictionary
:param requestor: The issuer of the AuthN request
:param outstanding: A dictionary with session IDs as keys and
:param outstanding: A dictionary with session IDs as keys and
the original web request from the user before redirection
as values.
:param log: where loggin should go.
:return: A 2-tuple of identity information (in the form of a
:return: A 2-tuple of identity information (in the form of a
dictionary) and where the user should really be sent. This
might differ from what the IdP thinks since I don't want
to reveal verything to it and it might not trust me.
@@ -115,18 +114,18 @@ class Saml2Client(object):
saml_response = post['SAMLResponse']
except KeyError:
return None
if saml_response:
ar = authn_response(self.config, requestor, outstanding, log,
aresp = authn_response(self.config, requestor, outstanding, log,
debug=self.debug)
ar.loads(saml_response)
aresp.loads(saml_response)
if self.debug:
log and log.info(ar)
return ar.verify()
log and log.info(aresp)
return aresp.verify()
return None
def authn_request(self, query_id, destination, service_url, spentityid,
def authn_request(self, query_id, destination, service_url, spentityid,
my_name, vorg="", scoping=None, log=None, sign=False):
""" Creates an authentication request.
@@ -152,7 +151,7 @@ class Saml2Client(object):
if scoping:
prel["scoping"] = scoping
name_id_policy = {
"allow_create": "true"
}
@@ -166,37 +165,37 @@ class Saml2Client(object):
pass
if sign:
prel["signature"] = pre_signature_part(prel["id"],
self.sc.my_cert, id=1)
prel["signature"] = pre_signature_part(prel["id"],
self.sec.my_cert, id=1)
prel["name_id_policy"] = name_id_policy
prel["issuer"] = { "text": spentityid }
if log:
log.info("DICT VERSION: %s" % prel)
return "%s" % signed_instance_factory(samlp.AuthnRequest, prel,
self.sc)
def authenticate(self, spentityid, location="", service_url="",
return "%s" % signed_instance_factory(samlp.AuthnRequest, prel,
self.sec)
def authenticate(self, spentityid, location="", service_url="",
my_name="", relay_state="",
binding=saml2.BINDING_HTTP_REDIRECT, log=None,
vorg="", scoping=None):
""" Sends an authentication request.
:param spentityid: The SP EntityID
:param binding: How the authentication request should be sent to the
:param binding: How the authentication request should be sent to the
IdP
:param location: Where the IdP is.
:param service_url: The SP's service URL
:param my_name: The providers name
:param relay_state: To where the user should be returned after
:param relay_state: To where the user should be returned after
successfull log in.
:param binding: Which binding to use for sending the request
:param log: Where to write log messages
:param vorg: The entity_id of the virtual organization I'm a member of
:param scoping: For which IdPs this query are aimed.
:return: AuthnRequest response
"""
@@ -206,8 +205,8 @@ class Saml2Client(object):
log.info("service_url: %s" % service_url)
log.info("my_name: %s" % my_name)
session_id = sid()
authen_req = self.authn_request(session_id, location,
service_url, spentityid, my_name, vorg,
authen_req = self.authn_request(session_id, location,
service_url, spentityid, my_name, vorg,
scoping, log)
log and log.info("AuthNReq: %s" % authen_req)
@@ -238,32 +237,34 @@ class Saml2Client(object):
else:
raise Exception("Unkown binding type: %s" % binding)
return (session_id, response)
def create_attribute_query(self, session_id, subject_id, issuer,
destination, attribute=None, sp_name_qualifier=None,
name_qualifier=None, nameformat=None, sign=False):
""" Constructs an AttributeQuery
def create_attribute_query(self, session_id, subject_id, issuer,
destination, attribute=None, sp_name_qualifier=None,
name_qualifier=None, nameid_format=None, sign=False):
""" Constructs an AttributeQuery
:param subject_id: The identifier of the subject
:param destination: To whom the query should be sent
:param attribute: A dictionary of attributes and values that is
:param attribute: A dictionary of attributes and values that is
asked for. The key are one of 4 variants:
3-tuple of name_format,name and friendly_name,
2-tuple of name_format and name,
1-tuple with name or
1-tuple with name or
just the name as a string.
:param sp_name_qualifier: The unique identifier of the
service provider or affiliation of providers for whom the
:param sp_name_qualifier: The unique identifier of the
service provider or affiliation of providers for whom the
identifier was generated.
:param name_qualifier: The unique identifier of the identity
:param name_qualifier: The unique identifier of the identity
provider that generated the identifier.
:param nameid_format: The format of the name ID
:param sign: Whether the query should be signed or not.
:return: An AttributeQuery instance
"""
subject = args2dict(
name_id = args2dict(subject_id, format=nameformat,
name_id = args2dict(subject_id, format=nameid_format,
sp_name_qualifier=sp_name_qualifier,
name_qualifier=name_qualifier),
)
@@ -278,45 +279,46 @@ class Saml2Client(object):
}
if sign:
prequery["signature"] = pre_signature_part(prequery["id"],
self.sc.my_cert, 1)
prequery["signature"] = pre_signature_part(prequery["id"],
self.sec.my_cert, 1)
if attribute:
prequery["attribute"] = do_attributes(attribute)
request = make_instance(samlp.AttributeQuery, prequery)
if sign:
signed_req = self.sc.sign_assertion_using_xmlsec("%s" % request)
signed_req = self.sec.sign_assertion_using_xmlsec("%s" % request)
return samlp.attribute_query_from_string(signed_req)
else:
return request
def attribute_query(self, subject_id, issuer, destination,
attribute=None, sp_name_qualifier=None, name_qualifier=None,
format=None, log=None):
def attribute_query(self, subject_id, issuer, destination,
attribute=None, sp_name_qualifier=None, name_qualifier=None,
nameid_format=None, log=None):
""" Does a attribute request from an attribute authority
:param subject_id: The identifier of the subject
:param destination: To whom the query should be sent
:param attribute: A dictionary of attributes and values that is asked for
:param sp_name_qualifier: The unique identifier of the
service provider or affiliation of providers for whom the
:param sp_name_qualifier: The unique identifier of the
service provider or affiliation of providers for whom the
identifier was generated.
:param name_qualifier: The unique identifier of the identity
:param name_qualifier: The unique identifier of the identity
provider that generated the identifier.
:param nameid_format: The format of the name ID
:return: The attributes returned
"""
session_id = sid()
request = self.create_attribute_query(session_id, subject_id,
issuer, destination, attribute, sp_name_qualifier,
name_qualifier, nameformat=format)
request = self.create_attribute_query(session_id, subject_id,
issuer, destination, attribute, sp_name_qualifier,
name_qualifier, nameid_format=nameid_format)
log and log.info("Request, created: %s" % request)
soapclient = SOAPClient(destination, self.config["key_file"],
soapclient = SOAPClient(destination, self.config["key_file"],
self.config["cert_file"])
log and log.info("SOAP client initiated")
try:
@@ -324,32 +326,32 @@ class Saml2Client(object):
except Exception, exc:
log and log.info("SoapClient exception: %s" % (exc,))
return None
log and log.info("SOAP request sent and got response: %s" % response)
if response:
log and log.info("Verifying response")
ar = authn_response(self.config, issuer, {session_id:""}, log)
session_info = ar.loads(response).verify().session_info()
aresp = authn_response(self.config, issuer, {session_id:""}, log)
session_info = aresp.loads(response).verify().session_info()
log and log.info("session: %s" % session_info)
return session_info
else:
log and log.info("No response")
return None
def make_logout_request(self, session_id, destination, issuer,
reason=None, not_on_or_after=None):
""" Constructs a LogoutRequest
:param subject_id: The identifier of the subject
:param reason: An indication of the reason for the logout, in the
:param reason: An indication of the reason for the logout, in the
form of a URI reference.
:param not_on_or_after: The time at which the request expires,
:param not_on_or_after: The time at which the request expires,
after which the recipient may discard the message.
:return: An AttributeQuery instance
"""
prel = {
"id": sid(),
"version": VERSION,
@@ -358,21 +360,21 @@ class Saml2Client(object):
"issuer": issuer,
"session_index": session_id,
}
if reason:
prel["reason"] = reason
if not_on_or_after:
prel["not_on_or_after"] = not_on_or_after
return make_instance(samlp.LogoutRequest, prel)
return make_instance(samlp.LogoutRequest, prel)
def logout(self, session_id, destination,
issuer, reason="", not_on_or_after=None):
issuer, reason="", not_on_or_after=None):
return self.make_logout_request(session_id, destination,
issuer, reason, not_on_or_after)
# ----------------------------------------------------------------------
ROW = """<tr><td>%s</td><td>%s</td></tr>"""
@@ -396,7 +398,7 @@ def _print_statement(statem):
txt.append(ROW % (key, _print_statement(val)))
else:
txt.append(ROW % (key, val))
txt.append("</table>")
return "\n".join(txt)

View File

@@ -2,10 +2,9 @@
# -*- coding: utf-8 -*-
#
from saml2 import metadata, utils, saml
from saml2 import metadata
from saml2.assertion import Policy
from saml2.attribute_converter import ac_factory, AttributeConverter
import re
class MissingValue(Exception):
pass
@@ -22,24 +21,24 @@ def entity_id2url(meta, entity_id):
return meta.single_sign_on_services(entity_id)[0]
class Config(dict):
def sp_check(self, config, metadata=None):
""" config["idp"] is a dictionary with entity_ids as keys and
urls as values
def _sp_check(self, config, metadat=None):
""" Verify that the SP configuration part is correct.
"""
if metadata:
if metadat:
if "idp" not in config or len(config["idp"]) == 0:
eids = [e for e, d in metadata.entity.items() if "idp_sso" in d]
eids = [e for e, d in metadat.entity.items() if "idp_sso" in d]
config["idp"] = {}
for eid in eids:
try:
config["idp"][eid] = entity_id2url(metadata, eid)
except IndexError, KeyError:
config["idp"][eid] = entity_id2url(metadat, eid)
except (IndexError, KeyError):
if not config["idp"][eid]:
raise MissingValue
else:
for eid, url in config["idp"].items():
if not url:
config["idp"][eid] = entity_id2url(metadata, eid)
config["idp"][eid] = entity_id2url(metadat, eid)
else:
assert "idp" in config
assert len(config["idp"]) > 0
@@ -47,7 +46,7 @@ class Config(dict):
assert "url" in config
assert "name" in config
def idp_aa_check(self, config):
def _idp_aa_check(self, config):
assert "url" in config
if "assertions" in config:
config["policy"] = Policy(config["assertions"])
@@ -55,9 +54,9 @@ class Config(dict):
elif "policy" in config:
config["policy"] = Policy(config["policy"])
def load_metadata(self, metadata_conf, xmlsec_binary):
def load_metadata(self, metadata_conf, xmlsec_binary, acs):
""" Loads metadata into an internal structure """
metad = metadata.MetaData(xmlsec_binary)
metad = metadata.MetaData(xmlsec_binary, acs)
if "local" in metadata_conf:
for mdfile in metadata_conf["local"]:
metad.import_metadata(open(mdfile).read(), mdfile)
@@ -86,26 +85,27 @@ class Config(dict):
else:
config["key_file"] = None
if "metadata" in config:
config["metadata"] = self.load_metadata(config["metadata"],
config["xmlsec_binary"])
if "attribute_map_dir" in config:
config["attrconverters"] = ac_factory(
config["attribute_map_dir"])
else:
config["attrconverters"] = [AttributeConverter()]
if "metadata" in config:
config["metadata"] = self.load_metadata(config["metadata"],
config["xmlsec_binary"],
config["attrconverters"])
if "sp" in config["service"]:
#print config["service"]["sp"]
if "metadata" in config:
self.sp_check(config["service"]["sp"], config["metadata"])
self._sp_check(config["service"]["sp"], config["metadata"])
else:
self.sp_check(config["service"]["sp"])
self._sp_check(config["service"]["sp"])
if "idp" in config["service"]:
self.idp_aa_check(config["service"]["idp"])
self._idp_aa_check(config["service"]["idp"])
if "aa" in config["service"]:
self.idp_aa_check(config["service"]["aa"])
self._idp_aa_check(config["service"]["aa"])
for key, val in config.items():
self[key] = val

View File

@@ -20,7 +20,7 @@ Contains classes and functions to alleviate the handling of SAML metadata
"""
import httplib2
import sys
import sys
from decorator import decorator
from saml2 import md, BINDING_HTTP_POST
@@ -28,32 +28,35 @@ from saml2 import samlp, BINDING_HTTP_REDIRECT, BINDING_SOAP
#from saml2.time_util import str_to_time
from saml2.sigver import make_temp, cert_from_key_info, verify_signature
from saml2.time_util import valid
from saml2.attribute_converter import ava_fro
@decorator
def keep_updated(f, self, entity_id, *args, **kwargs):
def keep_updated(func, self, entity_id, *args, **kwargs):
#print "In keep_updated"
try:
if not valid(self.entity[entity_id]["valid_until"]):
self.reload_entity(entity_id)
except KeyError:
pass
return f(self, entity_id, *args, **kwargs)
return func(self, entity_id, *args, **kwargs)
class MetaData(object):
""" A class to manage metadata information """
def __init__(self, xmlsec_binary=None, log=None):
def __init__(self, xmlsec_binary=None, attrconv=None, log=None):
self.log = log
self.xmlsec_binary = xmlsec_binary
self.attrconv = attrconv or []
self._loc_key = {}
self._loc_bind = {}
self.entity = {}
self.valid_to = None
self.cache_until = None
self.log = log
self.xmlsec_binary = xmlsec_binary
self.http = httplib2.Http()
self._import = {}
self._wants = {}
def _vo_metadata(self, entity_descriptor, entity, tag):
"""
Pick out the Affiliation descriptors from an entity
@@ -68,10 +71,10 @@ class MetaData(object):
return
members = []
for tafd in afd: # should really never be more than one
for tafd in afd: # should really never be more than one
members.extend(
[member.text.strip() for member in tafd.affiliate_member])
if members != []:
entity[tag] = members
@@ -89,6 +92,9 @@ class MetaData(object):
return
ssds = []
required = []
optional = []
#print "..... %s ..... " % entity_descriptor.entity_id
for tssd in ssd:
# Only want to talk to SAML 2.0 entities
if samlp.NAMESPACE not in \
@@ -102,9 +108,26 @@ class MetaData(object):
certs.extend(cert_from_key_info(key_desc.key_info))
certs = [make_temp(c, suffix=".der") for c in certs]
for acs in tssd.attribute_consuming_service:
for attr in acs.requested_attribute:
print "==", attr
if attr.is_required == "true":
required.append(attr)
else:
optional.append(attr)
for acs in tssd.assertion_consumer_service:
self._loc_key[acs.location] = certs
if required or optional:
#print "REQ",required
#print "OPT",optional
self._wants[entity_descriptor.entity_id] = (ava_fro(self.attrconv,
required),
ava_fro(self.attrconv,
optional))
if ssds:
entity[tag] = ssds
@@ -136,7 +159,7 @@ class MetaData(object):
certs = [make_temp(c, suffix=".der") for c in certs]
for sso in tidp.single_sign_on_service:
self._loc_key[sso.location] = certs
if idps:
entity[tag] = idps
@@ -153,7 +176,7 @@ class MetaData(object):
except AttributeError:
#print "No Attribute AD: %s" % entity_descriptor.entity_id
return
aads = []
for taad in attr_auth_descr:
# Remove everyone that doesn't talk SAML 2.0
@@ -168,10 +191,10 @@ class MetaData(object):
#print "binding", attr_serv.binding
if attr_serv.binding == BINDING_SOAP:
aserv.append(attr_serv)
if aserv == []:
continue
taad.attribute_service = aserv
# gather all the certs and place them in temporary files
@@ -185,12 +208,12 @@ class MetaData(object):
self._loc_key[sso.location].append(certs)
except KeyError:
self._loc_key[sso.location] = certs
aads.append(taad)
if aads != []:
entity[tag] = aads
aads.append(taad)
if aads != []:
entity[tag] = aads
def clear_from_source(self, source):
for eid in self._import[source]:
del self.entity[eid]
@@ -202,28 +225,28 @@ class MetaData(object):
return
self.clear_from_source(source)
if isinstance(source, basestring):
f = open(source)
self.import_metadata( f.read(), source)
f.close()
else:
self.import_external_metadata(source[0],source[1])
fil = open(source)
self.import_metadata( fil.read(), source)
fil.close()
else:
self.import_external_metadata(source[0], source[1])
def import_metadata(self, xml_str, source):
""" Import information; organization distinguish name, location and
certificates from a metadata file.
:param xml_str: The metadata as a XML string.
"""
# now = time.gmtime()
entities_descriptor = md.entities_descriptor_from_string(xml_str)
try:
valid(entities_descriptor.valid_until)
except AttributeError:
pass
for entity_descriptor in entities_descriptor.entity_descriptor:
try:
if not valid(entity_descriptor.valid_until):
@@ -238,7 +261,7 @@ class MetaData(object):
continue
except AttributeError:
pass
try:
self._import[source].append(entity_descriptor.entity_id)
except KeyError:
@@ -248,7 +271,7 @@ class MetaData(object):
entity["valid_until"] = entities_descriptor.valid_until
self._idp_metadata(entity_descriptor, entity, "idp_sso")
self._sp_metadata(entity_descriptor, entity, "sp_sso")
self._aad_metadata(entity_descriptor, entity,
self._aad_metadata(entity_descriptor, entity,
"attribute_authority")
self._vo_metadata(entity_descriptor, entity, "affiliation")
try:
@@ -259,7 +282,7 @@ class MetaData(object):
entity["contact"] = entity_descriptor.contact
except AttributeError:
pass
def import_external_metadata(self, url, cert=None):
""" Imports metadata by the use of HTTP GET.
If the fingerprint is known the file will be checked for
@@ -274,28 +297,28 @@ class MetaData(object):
if verify_signature(content, self.xmlsec_binary, cert, "pem",
"%s:%s" % (md.EntitiesDescriptor.c_namespace,
md.EntitiesDescriptor.c_tag)):
self.import_metadata(content, (url,cert))
self.import_metadata(content, (url, cert))
return True
else:
self.log and self.log.info("Response status: %s" % response.status)
return False
@keep_updated
def single_sign_on_services(self, entity_id,
def single_sign_on_services(self, entity_id,
binding = BINDING_HTTP_REDIRECT):
""" Get me all single-sign-on services that supports the specified
binding version.
:param entity_id: The EntityId
:param binding: A binding identifier
:return: list of single-sign-on service location run by the entity
:return: list of single-sign-on service location run by the entity
with the specified EntityId.
"""
# May raise KeyError
idps = self.entity[entity_id]["idp_sso"]
loc = []
#print idps
for idp in idps:
@@ -305,21 +328,21 @@ class MetaData(object):
if binding == sso.binding:
loc.append(sso.location)
return loc
@keep_updated
def attribute_services(self, entity_id):
try:
return self.entity[entity_id]["attribute_authority"]
except KeyError:
return []
def locations(self):
""" Returns all the locations that are know using this metadata file.
:return: A list of IdP locations
"""
return self._loc_key.keys()
def certs(self, loc):
""" Get all certificates that are used by a IdP at the specified
location. There can be more than one because of overlapping lifetimes
@@ -340,14 +363,14 @@ class MetaData(object):
return self.entity[entity_id]["affiliation"]
except KeyError:
return []
@keep_updated
def consumer_url(self, entity_id, binding=BINDING_HTTP_POST, _log=None):
try:
ssos = self.entity[entity_id]["sp_sso"]
except KeyError:
raise
# any default ?
for sso in ssos:
for acs in sso.assertion_consumer_service:
@@ -360,7 +383,7 @@ class MetaData(object):
return acs.location
return None
@keep_updated
def name(self, entity_id):
""" Find a name from the metadata about this entity id.
@@ -368,7 +391,7 @@ class MetaData(object):
,in that order, for the organization.
:param entityid: The Entity ID
:return: A name
:return: A name
"""
try:
org = self.entity[entity_id]["organization"]
@@ -386,16 +409,23 @@ class MetaData(object):
name = names[0].text
except KeyError:
name = ""
return name
@keep_updated
def wants(self, entity_id):
try:
return self._wants[entity_id]
except KeyError:
return ([], [])
@keep_updated
def attribute_consumer(self, entity_id):
try:
ssos = self.entity[entity_id]["sp_sso"]
except KeyError:
return ([], [])
required = []
optional = []
# What if there is more than one ? Can't be ?
@@ -405,41 +435,41 @@ class MetaData(object):
required.append(attr)
else:
optional.append(attr)
return (required, optional)
return (required, optional)
def _orgname(self, org, lang="en"):
if not org:
return ""
for ll in [lang,None]:
for spec in [lang, None]:
for name in org.organization_display_name:
if name.lang == ll:
if name.lang == spec:
return name.text.strip()
for name in org.organization_name:
if name.lang == ll:
if name.lang == spec:
return name.text.strip()
for name in org.organization_url:
if name.lang == ll:
if name.lang == spec:
return name.text.strip()
return ""
def _location(self, idpsso):
loc = []
for idp in idpsso:
for sso in idp.single_sign_on_service:
loc.append(sso.location)
return loc
@keep_updated
def _valid(self, entity_id):
return True
# @keep_updated
# def _valid(self, entity_id):
# return True
def idps(self):
idps = {}
for entity_id, edict in self.entity.items():
if "idp_sso" in edict:
self._valid(entity_id)
#idp_aa_check self._valid(entity_id)
if "organization" in edict:
name = self._orgname(edict["organization"],"en")
if not name:

View File

@@ -22,19 +22,18 @@ or attribute authority (AA) may use to conclude its tasks.
import shelve
import sys
from saml2 import saml, samlp, VERSION, make_instance
from saml2 import saml, samlp, VERSION, class_name
from saml2.utils import sid, decode_base64_and_inflate
from saml2.utils import response_factory
from saml2.utils import MissingValue, args2dict
from saml2.utils import success_status_factory, assertion_factory
from saml2.utils import OtherError, do_attribute_statement
from saml2.utils import success_status_factory
from saml2.utils import OtherError
from saml2.utils import VersionMismatch, UnknownPrincipal, UnsupportedBinding
from saml2.utils import status_from_exception_factory
from saml2.sigver import security_context, signed_instance_factory
from saml2.sigver import pre_signature_part
from saml2.time_util import instant, in_a_while
from saml2.config import Config
from saml2.cache import Cache
from saml2.assertion import Assertion, Policy
@@ -45,7 +44,7 @@ class UnknownVO(Exception):
class Identifier(object):
""" A class that handles identifiers of objects """
def __init__(self, dbname, entityid, voconf=None, debug=0, log=None):
self.map = shelve.open(dbname,writeback=True)
self.map = shelve.open(dbname, writeback=True)
self.entityid = entityid
self.voconf = voconf
self.debug = debug
@@ -102,11 +101,11 @@ class Identifier(object):
raise UnknownVO("%s" % sp_name_qualifier)
try:
format = vo_conf["nameid_format"]
nameid_format = vo_conf["nameid_format"]
except KeyError:
format = saml.NAMEID_FORMAT_PERSISTENT
nameid_format = saml.NAMEID_FORMAT_PERSISTENT
return args2dict(subj_id, format=format,
return args2dict(subj_id, format=nameid_format,
sp_name_qualifier=sp_name_qualifier)
def persistent_nameid(self, sp_name_qualifier, userid):
@@ -156,13 +155,14 @@ class Server(object):
self.log = log
self.debug = debug
self.ident = None
if config_file:
self.load_config(config_file)
elif config:
self.conf = config
self.metadata = self.conf["metadata"]
self.sc = security_context(self.conf, log)
self.sec = security_context(self.conf, log)
if cache:
self.cache = Cache(cache)
else:
@@ -173,11 +173,11 @@ class Server(object):
self.conf = Config()
self.conf.load_file(config_file)
if "subject_data" in self.conf:
self.id = Identifier(self.conf["subject_data"],
self.ident = Identifier(self.conf["subject_data"],
self.conf["entityid"], self.conf.vo_conf,
self.debug, self.log)
else:
self.id = None
self.ident = None
def issuer(self):
""" Return an Issuer precursor """
@@ -199,7 +199,7 @@ class Server(object):
response = {}
request_xml = decode_base64_and_inflate(enc_request)
try:
request = self.sc.correctly_signed_authn_request(request_xml)
request = self.sec.correctly_signed_authn_request(request_xml)
if self.log and self.debug:
self.log.info("Request was correctly signed")
except Exception:
@@ -324,7 +324,7 @@ class Server(object):
if sign:
assertion["signature"] = pre_signature_part(assertion["id"],
self.sc.my_cert, 1)
self.sec.my_cert, 1)
# Store which assertion that has been sent to which SP about which
# subject.
@@ -335,7 +335,7 @@ class Server(object):
response.update({"assertion":assertion})
return signed_instance_factory(samlp.Response, response, self.sc)
return signed_instance_factory(samlp.Response, response, self.sec)
# ------------------------------------------------------------------------
@@ -364,13 +364,12 @@ class Server(object):
# ------------------------------------------------------------------------
def do_aa_response(self, consumer_url, in_response_to, sp_entity_id,
identity=None, userid="", name_id=None, ip_address="",
issuer=None, status=None, sign=False,
name_id_policy=None):
identity=None, userid="", name_id=None, status=None,
sign=False, name_id_policy=None):
name_id = self.id.construct_nameid(self.conf.aa_policy(), userid,
name_id = self.ident.construct_nameid(self.conf.aa_policy(), userid,
sp_entity_id, identity)
return self._response(consumer_url, in_response_to,
sp_entity_id, identity, name_id,
status, sign, policy=self.conf.aa_policy())
@@ -395,7 +394,7 @@ class Server(object):
"""
try:
name_id = self.id.construct_nameid(self.conf.idp_policy(),
name_id = self.ident.construct_nameid(self.conf.idp_policy(),
userid, sp_entity_id, identity,
name_id_policy)
except IOError, exc:
@@ -419,7 +418,7 @@ class Server(object):
if sign:
try:
return self.sc.sign_statement_using_xmlsec(response,
return self.sec.sign_statement_using_xmlsec(response,
class_name(response))
except Exception, exc:
response = self.error_response(destination, in_response_to,

View File

@@ -230,3 +230,10 @@ def valid( valid_until ):
return True
else:
return False
def later_than(then, that):
then = str_to_time( then )
then = str_to_time( that )
return then >= that

View File

@@ -2,8 +2,7 @@
import time
import base64
import re
from saml2 import samlp, saml, VERSION, sigver, NAME_FORMAT_URI
from saml2 import samlp, VERSION, sigver
from saml2.time_util import instant
try:
@@ -191,7 +190,7 @@ def response_factory(signature=False, encrypt=False, **kwargs):
return args2dict(**kwargs)
def _attrval(val):
if isinstance(val, list) or isinstance(val,set):
if isinstance(val, list) or isinstance(val, set):
attrval = [args2dict(v) for v in val]
elif val == None:
attrval = None