From f28754df355224910dc80b626e28819e602cdcc4 Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Sat, 17 Apr 2010 20:12:43 +0200 Subject: [PATCH] pylinted --- src/saml2/__init__.py | 72 +++++++++++-- src/saml2/assertion.py | 123 +++++++++++++--------- src/saml2/attribute_converter.py | 124 ++++++++++++++++------ src/saml2/authnresponse.py | 122 ++++++++++++---------- src/saml2/client.py | 170 ++++++++++++++++--------------- src/saml2/config.py | 44 ++++---- src/saml2/metadata.py | 154 +++++++++++++++++----------- src/saml2/server.py | 41 ++++---- src/saml2/time_util.py | 7 ++ src/saml2/utils.py | 5 +- 10 files changed, 528 insertions(+), 334 deletions(-) diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index ab3d839..d6fedaa 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -376,7 +376,7 @@ def make_vals(val, klass, klass_inst=None, prop=None, part=False, else: try: cinst = klass().set_text(val) - except ValueError, excp: + except ValueError: if not part: cis = [make_vals(sval, klass, klass_inst, prop, True, base64encode) for sval in val] @@ -482,11 +482,11 @@ class SamlBase(ExtensionContainer): def become_child_element_of(self, tree): """ - Note: Only for use with classes that have a c_tag and c_namespace class member. It is in SamlBase so that it can be inherited but it should not be called on instances of SamlBase. + :param tree: The tree to which this instance should be a child """ new_child = self._to_element_tree() tree.append(new_child) @@ -519,15 +519,28 @@ class SamlBase(ExtensionContainer): self.__dict__[extension_attribute_name] = value def keyswv(self): + """ Return the keys of attributes or children that has values + + :return: list of keys + """ return [key for key, val in self.__dict__.items() if val] def keys(self): + """ Return all the keys that represent possible attributes and + children. + + :return: list of keys + """ keys = ['text'] keys.extend(self.c_attributes.values()) keys.extend([v[1] for v in self.c_children.values()]) return keys def children_with_values(self): + """ Returns all children that has values + + :return: Possibly empty list of children. + """ childs = [] for _, values in self.__class__.c_children.iteritems(): member = getattr(self, values[0]) @@ -541,7 +554,13 @@ class SamlBase(ExtensionContainer): return childs def set_text(self, val, base64encode=False): - """ """ + """ Sets the text property of this instance. + + :param val: The value of the text property + :param base64encode: Whether the value should be base64encoded + :return: The instance + """ + #print "set_text: %s" % (val,) if isinstance(val, bool): if val: @@ -560,7 +579,19 @@ class SamlBase(ExtensionContainer): return self def loadd(self, ava, base64encode=False): - """ """ + """ + Sets attributes, children, extension elements and extension + attributes of this element instance depending on what is in + the given dictionary. If there are already values on properties + those will be overwritten. If the keys in the dictionary does + not correspond to known attributes/children/.. they are ignored. + + :param ava: The dictionary + :param base64encode: Whether the values on attributes or texts on + children shoule be base64encoded. + :return: The instance + """ + for prop in self.c_attributes.values(): #print "# %s" % (prop) if prop in ava: @@ -578,7 +609,8 @@ class SamlBase(ExtensionContainer): #print "## %s, %s" % (prop, klassdef) if prop in ava: #print "### %s" % ava[prop] - if isinstance(klassdef, list): # means there can be a list of values + # means there can be a list of values + if isinstance(klassdef, list): make_vals(ava[prop], klassdef[0], self, prop, base64encode=base64encode) else: @@ -588,7 +620,8 @@ class SamlBase(ExtensionContainer): if "extension_elements" in ava: for item in ava["extension_elements"]: - self.extension_elements.append(ExtensionElement(item["tag"]).loadd(item)) + self.extension_elements.append(ExtensionElement( + item["tag"]).loadd(item)) if "extension_attributes" in ava: for key, val in ava["extension_attributes"].items(): @@ -598,22 +631,39 @@ class SamlBase(ExtensionContainer): def element_to_extension_element(element): - ee = ExtensionElement(element.c_tag, element.c_namespace, + """ + Convert an element into a extension element + + :param element: The element instance + :return: An extension element instance + """ + + exel = ExtensionElement(element.c_tag, element.c_namespace, text=element.text) for xml_attribute, member_name in element.c_attributes.iteritems(): member_value = getattr(element, member_name) if member_value is not None: - ee.attributes[xml_attribute] = member_value + exel.attributes[xml_attribute] = member_value - ee.children = [element_to_extension_element(c) \ + exel.children = [element_to_extension_element(c) \ for c in element.children_with_values()] - return ee + return exel def extension_element_to_element(extension_element, translation_functions, namespace=None): - """ """ + """ Convert an extension element to a normal element. + In order to do this you need to have an idea of what type of + element it is. Or rather which module it belongs to. + + :param extension_element: The extension element + :prama translation_functions: A dictionary which klass identifiers + as keys and string-to-element translations functions as values + :param namespace: The namespace of the translation functions. + :return: An element instance or None + """ + try: element_namespace = extension_element.namespace except AttributeError: diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index 3e0c586..75d5923 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -35,10 +35,10 @@ def _filter_values(vals, required=None, optional=None): :param optional: The optional values :return: The set of values after filtering """ - + if not required and not optional: return vals - + valr = [] valo = [] if required: @@ -54,17 +54,17 @@ def _filter_values(vals, required=None, optional=None): valr.append(val) elif val in ovals: valo.append(val) - + valo.extend(valr) if rvals: if len(rvals) == len(valr): return valo else: - a = set(rvals) + _ = set(rvals) raise MissingValue("Required attribute value missing") else: return valo - + def _combine(required=None, optional=None): res = {} if not required: @@ -81,12 +81,12 @@ def _combine(required=None, optional=None): res[(attr.name, attr.friendly_name)] = part else: res[(attr.name, attr.friendly_name)] = (attr.attribute_value, []) - + for oat in optional: tag = (oat.name, oat.friendly_name) if tag not in res: res[tag] = ([], oat.attribute_value) - + return res def filter_on_attributes(ava, required=None, optional=None): @@ -102,24 +102,48 @@ def filter_on_attributes(ava, required=None, optional=None): res[attr[1]] = _filter_values(ava[attr[1]], vals[0], vals[1]) else: print >> sys.stderr, ava.keys() - raise MissingValue("Required attribute missing: '%s'" % (attr,)) + raise MissingValue("Required attribute missing: '%s'" % (attr[0],)) return res +def filter_on_demands(ava, required={}, optional={}): + """ Never return more than is needed """ + + # Is all what's required there: + + for attr, vals in required.items(): + if attr in ava: + if vals: + for val in vals: + if val not in ava[attr]: + raise MissingValue( + "Required attribute value missing: %s,%s" % (attr, + val)) + else: + raise MissingValue("Required attribute missing: %s" % (attr,)) + + # OK, so I can imaging releasing values that are not absolutely necessary + # but not attributes + for attr, vals in ava.items(): + if attr not in required and attr not in optional: + del ava[attr] + + return ava + def filter_attribute_value_assertions(ava, attribute_restrictions=None): - """ Will weed out attribute values and values according to the - rules defined in the attribute restrictions. If filtering results in + """ Will weed out attribute values and values according to the + rules defined in the attribute restrictions. If filtering results in an attribute without values, then the attribute is removed from the assertion. - :param ava: The incoming attribute value assertion + :param ava: The incoming attribute value assertion (dictionary) :param attribute_restrictions: The rules that govern which attributes - and values that are allowed. + and values that are allowed. (dictionary) :return: The modified attribute value assertion """ if not attribute_restrictions: return ava - + for attr, vals in ava.items(): if attr in attribute_restrictions: if attribute_restrictions[attr]: @@ -136,7 +160,7 @@ def filter_attribute_value_assertions(ava, attribute_restrictions=None): else: del ava[attr] return ava - + class Policy(object): """ handles restrictions on assertions """ @@ -145,11 +169,11 @@ class Policy(object): self.compile(restrictions) else: self._restrictions = None - + def compile(self, restrictions): """ This is only for IdPs or AAs, and it's about limiting what - is returned to the SP. - In the configuration file, restrictions on which values that + is returned to the SP. + In the configuration file, restrictions on which values that can be returned are specified with the help of regular expressions. This function goes through and pre-compiles the regular expressions. @@ -163,25 +187,25 @@ class Policy(object): for _, spec in self._restrictions.items(): if spec == None: continue - + try: restr = spec["attribute_restrictions"] except KeyError: continue - + if restr == None: continue - + for key, values in restr.items(): if not values: spec["attribute_restrictions"][key] = None continue - + spec["attribute_restrictions"][key] = \ [re.compile(value) for value in values] return self._restrictions - + def get_nameid_format(self, sp_entity_id): try: form = self._restrictions[sp_entity_id]["nameid_format"] @@ -190,7 +214,7 @@ class Policy(object): form = self._restrictions["default"]["nameid_format"] except KeyError: form = saml.NAMEID_FORMAT_TRANSIENT - + return form def get_name_form(self, sp_entity_id): @@ -203,7 +227,7 @@ class Policy(object): form = self._restrictions["default"]["name_form"] except KeyError: pass - + return form def get_lifetime(self, sp_entity_id): @@ -211,7 +235,7 @@ class Policy(object): spec = {"hours":1} if not self._restrictions: return spec - + try: spec = self._restrictions[sp_entity_id]["lifetime"] except KeyError: @@ -219,13 +243,13 @@ class Policy(object): spec = self._restrictions["default"]["lifetime"] except KeyError: pass - - return spec + return spec + def get_attribute_restriction(self, sp_entity_id): if not self._restrictions: return None - + try: try: restrictions = self._restrictions[sp_entity_id][ @@ -238,9 +262,9 @@ class Policy(object): restrictions = None except KeyError: restrictions = None - - return restrictions + return restrictions + def _not_on_or_after(self, sp_entity_id): """ When the assertion stops being valid, should not be used after this time. @@ -249,13 +273,13 @@ class Policy(object): """ return in_a_while(**self.get_lifetime(sp_entity_id)) - + def filter(self, ava, sp_entity_id, required=None, optional=None): """ What attribute and attribute values returns depends on what the SP has said it wants in the request or in the metadata file and what the IdP/AA wants to release. An assumption is that what the SP - asks for overrides whatever is in the metadata. But of course the - IdP never releases anything it doesn't want to. + asks for overrides whatever is in the metadata. But of course the + IdP never releases anything it doesn't want to. :param ava: The information about the subject as a dictionary :param sp_entity_id: The entity ID of the SP @@ -263,18 +287,18 @@ class Policy(object): :param optional: Attributes that the SP thinks is optional :return: A possibly modified AVA """ - - ava = filter_attribute_value_assertions(ava, - self.get_attribute_restriction(sp_entity_id)) - + + ava = filter_attribute_value_assertions(ava, + self.get_attribute_restriction(sp_entity_id)) + if required or optional: ava = filter_on_attributes(ava, required, optional) - + return ava - + def restrict(self, ava, sp_entity_id, metadata=None): - """ Identity attribute names are expected to be expressed in + """ Identity attribute names are expected to be expressed in the local lingo (== friendlyName) :return: A filtered ava according to the IdPs/AAs rules and @@ -283,16 +307,17 @@ class Policy(object): """ if metadata: (required, optional) = metadata.attribute_consumer(sp_entity_id) + #(required, optional) = metadata.wants(sp_entity_id) else: required = optional = None - + return self.filter(ava, sp_entity_id, required, optional) def conditions(self, sp_entity_id): return args2dict( - not_before=instant(), + not_before=instant(), # How long might depend on who's getting it - not_on_or_after=self._not_on_or_after(sp_entity_id), + not_on_or_after=self._not_on_or_after(sp_entity_id), audience_restriction=args2dict( audience=args2dict(sp_entity_id))) @@ -301,16 +326,16 @@ class Assertion(dict): def __init__(self, dic=None): dict.__init__(self, dic) - + def _authn_statement(self): return args2dict(authn_instant=instant(), session_index=sid()) - - def construct(self, sp_entity_id, in_response_to, name_id, attrconvs, - policy, issuer): + def construct(self, sp_entity_id, in_response_to, name_id, attrconvs, + policy, issuer): + attr_statement = from_local(attrconvs, self, policy.get_name_form(sp_entity_id)) - + # start using now and for a hour conds = policy.conditions(sp_entity_id) @@ -326,6 +351,6 @@ class Assertion(dict): subject_confirmation_data = \ args2dict(in_response_to=in_response_to))), ) - + def apply_policy(self, sp_entity_id, policy, metadata=None): return policy.restrict(self, sp_entity_id, metadata) diff --git a/src/saml2/attribute_converter.py b/src/saml2/attribute_converter.py index e136514..aece875 100644 --- a/src/saml2/attribute_converter.py +++ b/src/saml2/attribute_converter.py @@ -17,6 +17,7 @@ import os from saml2.utils import args2dict +from saml2.saml import NAME_FORMAT_URI class UnknownNameFormat(Exception): pass @@ -26,44 +27,82 @@ def ac_factory(path): for tup in os.walk(path): if tup[2]: - ac = AttributeConverter(os.path.basename(tup[0])) + atco = AttributeConverter(os.path.basename(tup[0])) for name in tup[2]: fname = os.path.join(tup[0], name) if name.endswith(".py"): name = name[:-3] - ac.set(name,fname) - ac.adjust() - acs.append(ac) + atco.set(name, fname) + atco.adjust() + acs.append(atco) return acs +def ava_fro(acs, statement): + """ translates attributes according to their name_formats """ + if statement == []: + return {} + + acsdic = dict([(ac.format, ac) for ac in acs]) + acsdic[None] = acsdic[NAME_FORMAT_URI] + return dict([acsdic[a.name_format].ava_fro(a) for a in statement]) + def to_local(acs, statement): if not acs: acs = [AttributeConverter()] ava = [] - for ac in acs: + for aconv in acs: try: - ava = ac.fro(statement) + ava = aconv.fro(statement) break except UnknownNameFormat: pass return ava def from_local(acs, ava, name_format): - for ac in acs: + for aconv in acs: #print ac.format, name_format - if ac.format == name_format: + if aconv.format == name_format: #print "Found a name_form converter" - return ac.to(ava) + return aconv.to(ava) return None +def from_local_name(acs, attr, name_format): + """ + :param acs: List of AttributeConverter instances + :param attr: attribute name as string + :param name_format: Which name-format it should be translated to + :return: A dictionary suitable to feed to make_instance + """ + for aconv in acs: + #print ac.format, name_format + if aconv.format == name_format: + #print "Found a name_form converter" + return aconv.to_format(attr) + return attr + +def to_local_name(acs, attr): + """ + :param acs: List of AttributeConverter instances + :param attr: an Attribute instance + :return: The local attribute name + """ + for aconv in acs: + lattr = aconv.from_format(attr) + if lattr: + return lattr + + return attr.friendly_name + class AttributeConverter(object): """ Converts from an attribute statement to a key,value dictionary and vice-versa """ - def __init__(self, format=""): - self.format = format + def __init__(self, name_format=""): + self.name_format = name_format + self._to = None + self._fro = None def set(self, name, filename): if name == "to": @@ -101,6 +140,24 @@ class AttributeConverter(object): result[name].append(value.text.strip()) return result + def ava_fro(self, attribute): + try: + attr = self._fro[attribute.name.strip()] + except (AttributeError, KeyError): + try: + attr = attribute.friendly_name.strip() + except AttributeError: + attr = attribute.name.strip() + + val = [] + for value in attribute.attribute_value: + if not value.text: + val.append('') + else: + val.append(value.text.strip()) + + return (attr, val) + def fro(self, statement): """ Get the attributes and the attribute values @@ -108,41 +165,48 @@ class AttributeConverter(object): :return: A dictionary containing attributes and values """ - if not self.format: + if not self.name_format: return self.fail_safe_fro(statement) result = {} for attribute in statement.attribute: - if attribute.name_format and self.format and \ - attribute.name_format != self.format: + if attribute.name_format and self.name_format and \ + attribute.name_format != self.name_format: raise UnknownNameFormat - try: - name = self._fro[attribute.name.strip()] - except (AttributeError, KeyError): - try: - name = attribute.friendly_name.strip() - except AttributeError: - name = attribute.name.strip() - - result[name] = [] - for value in attribute.attribute_value: - if not value.text: - result[name].append('') - else: - result[name].append(value.text.strip()) - + (key, val) = self.ava_fro(attribute) + result[key] = val + if not result: return self.fail_safe_fro(statement) else: return result + def to_format(self, attr): + try: + return args2dict(name=self._to[attr], name_format=self.name_format, + friendly_name=attr) + except KeyError: + return args2dict(name=attr) + + def from_format(self, attr): + """ + :param attr: An saml.Attribute instance + :return: The local attribute name or "" if no mapping could be made + """ + if self.name_format == attr.name_format: + try: + return self._fro[attr.name] + except KeyError: + pass + return "" + def to(self, ava): attributes = [] for key, value in ava.items(): try: attributes.append(args2dict(name=self._to[key], - name_format=self.format, + name_format=self.name_format, friendly_name=key, attribute_value=value)) except KeyError: diff --git a/src/saml2/authnresponse.py b/src/saml2/authnresponse.py index 69629d9..d4b73b3 100644 --- a/src/saml2/authnresponse.py +++ b/src/saml2/authnresponse.py @@ -20,7 +20,10 @@ import time from saml2.time_util import str_to_time -from saml2 import samlp, saml +from saml2 import samlp +from saml2 import saml +from saml2 import extension_element_to_element + from saml2.sigver import security_context from saml2.attribute_converter import to_local @@ -46,7 +49,7 @@ def _use_before(condition, slack): #print "NOW: %s" % now not_before = time.mktime(str_to_time(condition.not_before)) #print "not_before: %d" % not_before - + if not_before > now + slack: # Can't use it yet raise Exception("Can't use it yet %s <= %s" % (not_before, now)) @@ -63,15 +66,15 @@ def for_me(condition, myself ): pass return False - + # --------------------------------------------------------------------------- class IncorrectlySigned(Exception): pass - + # --------------------------------------------------------------------------- - -def authn_response(conf, requestor, outstanding_queries=None, log=None, + +def authn_response(conf, requestor, outstanding_queries=None, log=None, timeslack=0, debug=0): sec = security_context(conf) if not timeslack: @@ -79,15 +82,15 @@ def authn_response(conf, requestor, outstanding_queries=None, log=None, timeslack = int(conf["timeslack"]) except KeyError: pass - - return AuthnResponse(sec, conf.attribute_converters, requestor, - outstanding_queries, log, timeslack, debug) + return AuthnResponse(sec, conf.attribute_converters, requestor, + outstanding_queries, log, timeslack, debug) + class AuthnResponse(object): - def __init__(self, security_context, attribute_converters, requestor, + def __init__(self, sec_context, attribute_converters, requestor, outstanding_queries=None, log=None, timeslack=0, debug=0): - self.sc = security_context + self.sec = sec_context self.attribute_converters = attribute_converters self.requestor = requestor if outstanding_queries: @@ -100,9 +103,15 @@ class AuthnResponse(object): self.debug = debug if self.debug and not self.log: self.debug = 0 - self.clear() - self.xmlstr = "" + self.xmlstr = "" + self.came_from = "" + self.name_id = "" + self.ava = None + self.response = None + self.not_on_or_after = 0 + self.assertion = None + def loads(self, xmldata, decode=True): if self.debug: self.log.info("--- Loads AuthnResponse ---") @@ -112,11 +121,11 @@ class AuthnResponse(object): decoded_xml = xmldata # own copy - self.xmlstr = decoded_xml[:] + self.xmlstr = decoded_xml[:] if self.debug: self.log.info("xmlstr: %s" % (self.xmlstr,)) try: - self.response = self.sc.correctly_signed_response(decoded_xml) + self.response = self.sec.correctly_signed_response(decoded_xml) except Exception, excp: self.log and self.log.info("EXCEPTION: %s", excp) raise @@ -126,12 +135,12 @@ class AuthnResponse(object): self.log.error("Response was not correctly signed") self.log.info(decoded_xml) raise IncorrectlySigned() - + if self.debug: self.log.info("response: %s" % (self.response,)) - - return self + return self + def clear(self): self.xmlstr = "" self.came_from = "" @@ -140,7 +149,7 @@ class AuthnResponse(object): self.response = None self.not_on_or_after = 0 self.assertion = None - + def status_ok(self): if self.response.status: status = self.response.status @@ -152,13 +161,13 @@ class AuthnResponse(object): raise Exception( "Not successfull according to: %s" % \ status.status_code.value) - + def authn_statement_ok(self): # the assertion MUST contain one AuthNStatement assert len(self.assertion.authn_statement) == 1 # authn_statement = assertion.authn_statement[0] # check authn_statement.session_index - + def condition_ok(self, lax=False): # The Identity Provider MUST include a element #print "Conditions",assertion.conditions @@ -166,23 +175,23 @@ class AuthnResponse(object): condition = self.assertion.conditions if self.debug: self.log.info("condition: %s" % condition) - + try: self.not_on_or_after = _use_on_or_after(condition, self.timeslack) _use_before(condition, self.timeslack) - except Exception,excp: + except Exception, excp: self.log.error("Exception on condition: %s" % (excp,)) if not lax: raise else: self.not_on_or_after = 0 - + if not for_me(condition, self.requestor): if not lax: raise Exception("Not for me!!!") return True - + def get_identity(self): # The assertion can contain zero or one attributeStatements if not self.assertion.attribute_statement: @@ -194,13 +203,13 @@ class AuthnResponse(object): if self.debug: self.log.info("Attribute Statement: %s" % ( self.assertion.attribute_statement[0],)) - for ac in self.attribute_converters(): - self.log.info("Converts name format: %s" % (ac.format,)) - + for aconv in self.attribute_converters(): + self.log.info("Converts name format: %s" % (aconv.format,)) + ava = to_local(self.attribute_converters(), self.assertion.attribute_statement[0]) return ava - + def get_subject(self): # The assertion must contain a Subject assert self.assertion.subject @@ -221,7 +230,7 @@ class AuthnResponse(object): # The subject must contain a name_id assert subject.name_id self.name_id = subject.name_id.text.strip() - + def _assertion(self, assertion): self.assertion = assertion @@ -233,24 +242,24 @@ class AuthnResponse(object): if self.context == "AuthNReq": self.authn_statement_ok() - + if not self.condition_ok(): return None - + if self.debug: self.log.info("--- Getting Identity ---") - + self.ava = self.get_identity() if self.debug: self.log.info("--- AVA: %s" % (self.ava,)) - + self.get_subject() - + return True - + def _encrypted_assertion(self, xmlstr): - decrypt_xml = self.sc.decrypt(self.xmlstr) + decrypt_xml = self.sec.decrypt(xmlstr) if self.debug: self.log.info("Decryption successfull") @@ -259,57 +268,66 @@ class AuthnResponse(object): if self.debug: self.log.info("Parsed decrypted assertion successfull") - enc = self.response.encrypted_assertion[0].extension_elements[0] - assertion = extension_element_to_element(enc, + enc = self.response.encrypted_assertion[0].extension_elements[0] + assertion = extension_element_to_element(enc, saml.ELEMENT_FROM_STRING, namespace=saml.NAMESPACE) if self.debug: self.log.info("Decrypted Assertion: %s" % assertion) return self._assertion(assertion) - + def parse_assertion(self): try: assert len(self.response.assertion) == 1 or \ len(self.response.encrypted_assertion) == 1 except AssertionError: raise Exception("No assertion part") - - if self.response.assertion: + + if self.response.assertion: self.debug and self.log.info("***Unencrypted response***") return self._assertion(self.response.assertion[0]) else: self.debug and self.log.info("***Encrypted response***") return self._encrypted_assertion( self.response.encrypted_assertion[0]) - - return True + return True + def verify(self): - """ """ + """ Verify that the assertion is syntaktically correct and + the signature is correct if present.""" + self.status_ok() if self.parse_assertion(): return self else: return None - + def issuer(self): + """ Return the issuer of the reponse """ return self.response.issuer.text - + def session_id(self): + """ Returns the SessionID of the response """ return self.response.in_response_to - + def id(self): + """ Return the ID of the response """ return self.response.id - + def session_info(self): - return { "ava": self.ava, "name_id": self.name_id, + """ Returns a predefined set of information gleened from the + response. + :returns: Dictionary with information + """ + return { "ava": self.ava, "name_id": self.name_id, "came_from": self.came_from, "issuer": self.issuer(), "not_on_or_after": self.not_on_or_after } def __str__(self): return "%s" % self.xmlstr - + # ====================================================================== - + # session_info["ava"]["__userid"] = session_info["name_id"] # return session_info diff --git a/src/saml2/client.py b/src/saml2/client.py index 3fda725..d943aba 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -23,19 +23,17 @@ import os import urllib import saml2 import base64 -import time -import sys -from saml2.time_util import str_to_time, instant + +from saml2.time_util import instant from saml2.utils import sid, deflate_and_base64_encode from saml2.utils import do_attributes, args2dict -from saml2 import samlp, saml, extension_element_to_element -from saml2 import VERSION, class_name, make_instance +from saml2 import samlp, saml +from saml2 import VERSION, make_instance from saml2.sigver import pre_signature_part from saml2.sigver import security_context, signed_instance_factory from saml2.soap import SOAPClient -from saml2.attribute_converter import to_local from saml2.authnresponse import authn_response DEFAULT_BINDING = saml2.BINDING_HTTP_REDIRECT @@ -61,51 +59,52 @@ class Saml2Client(object): self.config = config if "metadata" in config: self.metadata = config["metadata"] - self.sc = security_context(config) + self.sec = security_context(config) self.debug = debug - + def _init_request(self, request, destination): #request.id = sid() request.version = VERSION request.issue_instant = instant() request.destination = destination - return request - + return request + def idp_entry(self, name=None, location=None, provider_id=None): res = {} - if name: + if name: res["name"] = name - if location: + if location: res["loc"] = location - if provider_id: + if provider_id: res["provider_id"] = provider_id if res: return res else: return None - + def scoping(self, idp_ents): return { "idp_list": { "idp_entry": idp_ents } } - + def scoping_from_metadata(self, entityid, location): name = self.metadata.name(entityid) - return make_instance(self.scoping([self.idp_entry(name, location)])) - + return make_instance(samlp.Scoping, + self.scoping([self.idp_entry(name, location)])) + def response(self, post, requestor, outstanding, log=None): """ Deal with the AuthnResponse :param post: The reply as a dictionary :param requestor: The issuer of the AuthN request - :param outstanding: A dictionary with session IDs as keys and + :param outstanding: A dictionary with session IDs as keys and the original web request from the user before redirection as values. :param log: where loggin should go. - :return: A 2-tuple of identity information (in the form of a + :return: A 2-tuple of identity information (in the form of a dictionary) and where the user should really be sent. This might differ from what the IdP thinks since I don't want to reveal verything to it and it might not trust me. @@ -115,18 +114,18 @@ class Saml2Client(object): saml_response = post['SAMLResponse'] except KeyError: return None - + if saml_response: - ar = authn_response(self.config, requestor, outstanding, log, + aresp = authn_response(self.config, requestor, outstanding, log, debug=self.debug) - ar.loads(saml_response) + aresp.loads(saml_response) if self.debug: - log and log.info(ar) - return ar.verify() - + log and log.info(aresp) + return aresp.verify() + return None - - def authn_request(self, query_id, destination, service_url, spentityid, + + def authn_request(self, query_id, destination, service_url, spentityid, my_name, vorg="", scoping=None, log=None, sign=False): """ Creates an authentication request. @@ -152,7 +151,7 @@ class Saml2Client(object): if scoping: prel["scoping"] = scoping - + name_id_policy = { "allow_create": "true" } @@ -166,37 +165,37 @@ class Saml2Client(object): pass if sign: - prel["signature"] = pre_signature_part(prel["id"], - self.sc.my_cert, id=1) - + prel["signature"] = pre_signature_part(prel["id"], + self.sec.my_cert, id=1) + prel["name_id_policy"] = name_id_policy prel["issuer"] = { "text": spentityid } if log: log.info("DICT VERSION: %s" % prel) - - return "%s" % signed_instance_factory(samlp.AuthnRequest, prel, - self.sc) - def authenticate(self, spentityid, location="", service_url="", + return "%s" % signed_instance_factory(samlp.AuthnRequest, prel, + self.sec) + + def authenticate(self, spentityid, location="", service_url="", my_name="", relay_state="", binding=saml2.BINDING_HTTP_REDIRECT, log=None, vorg="", scoping=None): """ Sends an authentication request. :param spentityid: The SP EntityID - :param binding: How the authentication request should be sent to the + :param binding: How the authentication request should be sent to the IdP :param location: Where the IdP is. :param service_url: The SP's service URL :param my_name: The providers name - :param relay_state: To where the user should be returned after + :param relay_state: To where the user should be returned after successfull log in. :param binding: Which binding to use for sending the request :param log: Where to write log messages :param vorg: The entity_id of the virtual organization I'm a member of :param scoping: For which IdPs this query are aimed. - + :return: AuthnRequest response """ @@ -206,8 +205,8 @@ class Saml2Client(object): log.info("service_url: %s" % service_url) log.info("my_name: %s" % my_name) session_id = sid() - authen_req = self.authn_request(session_id, location, - service_url, spentityid, my_name, vorg, + authen_req = self.authn_request(session_id, location, + service_url, spentityid, my_name, vorg, scoping, log) log and log.info("AuthNReq: %s" % authen_req) @@ -238,32 +237,34 @@ class Saml2Client(object): else: raise Exception("Unkown binding type: %s" % binding) return (session_id, response) - - def create_attribute_query(self, session_id, subject_id, issuer, - destination, attribute=None, sp_name_qualifier=None, - name_qualifier=None, nameformat=None, sign=False): - """ Constructs an AttributeQuery + + def create_attribute_query(self, session_id, subject_id, issuer, + destination, attribute=None, sp_name_qualifier=None, + name_qualifier=None, nameid_format=None, sign=False): + """ Constructs an AttributeQuery :param subject_id: The identifier of the subject :param destination: To whom the query should be sent - :param attribute: A dictionary of attributes and values that is + :param attribute: A dictionary of attributes and values that is asked for. The key are one of 4 variants: 3-tuple of name_format,name and friendly_name, 2-tuple of name_format and name, - 1-tuple with name or + 1-tuple with name or just the name as a string. - :param sp_name_qualifier: The unique identifier of the - service provider or affiliation of providers for whom the + :param sp_name_qualifier: The unique identifier of the + service provider or affiliation of providers for whom the identifier was generated. - :param name_qualifier: The unique identifier of the identity + :param name_qualifier: The unique identifier of the identity provider that generated the identifier. + :param nameid_format: The format of the name ID + :param sign: Whether the query should be signed or not. :return: An AttributeQuery instance """ - + subject = args2dict( - name_id = args2dict(subject_id, format=nameformat, + name_id = args2dict(subject_id, format=nameid_format, sp_name_qualifier=sp_name_qualifier, name_qualifier=name_qualifier), ) @@ -278,45 +279,46 @@ class Saml2Client(object): } if sign: - prequery["signature"] = pre_signature_part(prequery["id"], - self.sc.my_cert, 1) + prequery["signature"] = pre_signature_part(prequery["id"], + self.sec.my_cert, 1) if attribute: prequery["attribute"] = do_attributes(attribute) - + request = make_instance(samlp.AttributeQuery, prequery) if sign: - signed_req = self.sc.sign_assertion_using_xmlsec("%s" % request) + signed_req = self.sec.sign_assertion_using_xmlsec("%s" % request) return samlp.attribute_query_from_string(signed_req) - + else: return request - - def attribute_query(self, subject_id, issuer, destination, - attribute=None, sp_name_qualifier=None, name_qualifier=None, - format=None, log=None): + + def attribute_query(self, subject_id, issuer, destination, + attribute=None, sp_name_qualifier=None, name_qualifier=None, + nameid_format=None, log=None): """ Does a attribute request from an attribute authority - + :param subject_id: The identifier of the subject :param destination: To whom the query should be sent :param attribute: A dictionary of attributes and values that is asked for - :param sp_name_qualifier: The unique identifier of the - service provider or affiliation of providers for whom the + :param sp_name_qualifier: The unique identifier of the + service provider or affiliation of providers for whom the identifier was generated. - :param name_qualifier: The unique identifier of the identity + :param name_qualifier: The unique identifier of the identity provider that generated the identifier. + :param nameid_format: The format of the name ID :return: The attributes returned """ session_id = sid() - request = self.create_attribute_query(session_id, subject_id, - issuer, destination, attribute, sp_name_qualifier, - name_qualifier, nameformat=format) + request = self.create_attribute_query(session_id, subject_id, + issuer, destination, attribute, sp_name_qualifier, + name_qualifier, nameid_format=nameid_format) log and log.info("Request, created: %s" % request) - soapclient = SOAPClient(destination, self.config["key_file"], + soapclient = SOAPClient(destination, self.config["key_file"], self.config["cert_file"]) log and log.info("SOAP client initiated") try: @@ -324,32 +326,32 @@ class Saml2Client(object): except Exception, exc: log and log.info("SoapClient exception: %s" % (exc,)) return None - + log and log.info("SOAP request sent and got response: %s" % response) if response: log and log.info("Verifying response") - ar = authn_response(self.config, issuer, {session_id:""}, log) - session_info = ar.loads(response).verify().session_info() + aresp = authn_response(self.config, issuer, {session_id:""}, log) + session_info = aresp.loads(response).verify().session_info() log and log.info("session: %s" % session_info) return session_info else: log and log.info("No response") return None - + def make_logout_request(self, session_id, destination, issuer, reason=None, not_on_or_after=None): """ Constructs a LogoutRequest - + :param subject_id: The identifier of the subject - :param reason: An indication of the reason for the logout, in the + :param reason: An indication of the reason for the logout, in the form of a URI reference. - :param not_on_or_after: The time at which the request expires, + :param not_on_or_after: The time at which the request expires, after which the recipient may discard the message. :return: An AttributeQuery instance """ - + prel = { "id": sid(), "version": VERSION, @@ -358,21 +360,21 @@ class Saml2Client(object): "issuer": issuer, "session_index": session_id, } - + if reason: prel["reason"] = reason - + if not_on_or_after: prel["not_on_or_after"] = not_on_or_after - - return make_instance(samlp.LogoutRequest, prel) + return make_instance(samlp.LogoutRequest, prel) + def logout(self, session_id, destination, - issuer, reason="", not_on_or_after=None): + issuer, reason="", not_on_or_after=None): return self.make_logout_request(session_id, destination, issuer, reason, not_on_or_after) - + # ---------------------------------------------------------------------- ROW = """%s%s""" @@ -396,7 +398,7 @@ def _print_statement(statem): txt.append(ROW % (key, _print_statement(val))) else: txt.append(ROW % (key, val)) - + txt.append("") return "\n".join(txt) diff --git a/src/saml2/config.py b/src/saml2/config.py index e711ff8..a1a5972 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -2,10 +2,9 @@ # -*- coding: utf-8 -*- # -from saml2 import metadata, utils, saml +from saml2 import metadata from saml2.assertion import Policy from saml2.attribute_converter import ac_factory, AttributeConverter -import re class MissingValue(Exception): pass @@ -22,24 +21,24 @@ def entity_id2url(meta, entity_id): return meta.single_sign_on_services(entity_id)[0] class Config(dict): - def sp_check(self, config, metadata=None): - """ config["idp"] is a dictionary with entity_ids as keys and - urls as values + def _sp_check(self, config, metadat=None): + """ Verify that the SP configuration part is correct. + """ - if metadata: + if metadat: if "idp" not in config or len(config["idp"]) == 0: - eids = [e for e, d in metadata.entity.items() if "idp_sso" in d] + eids = [e for e, d in metadat.entity.items() if "idp_sso" in d] config["idp"] = {} for eid in eids: try: - config["idp"][eid] = entity_id2url(metadata, eid) - except IndexError, KeyError: + config["idp"][eid] = entity_id2url(metadat, eid) + except (IndexError, KeyError): if not config["idp"][eid]: raise MissingValue else: for eid, url in config["idp"].items(): if not url: - config["idp"][eid] = entity_id2url(metadata, eid) + config["idp"][eid] = entity_id2url(metadat, eid) else: assert "idp" in config assert len(config["idp"]) > 0 @@ -47,7 +46,7 @@ class Config(dict): assert "url" in config assert "name" in config - def idp_aa_check(self, config): + def _idp_aa_check(self, config): assert "url" in config if "assertions" in config: config["policy"] = Policy(config["assertions"]) @@ -55,9 +54,9 @@ class Config(dict): elif "policy" in config: config["policy"] = Policy(config["policy"]) - def load_metadata(self, metadata_conf, xmlsec_binary): + def load_metadata(self, metadata_conf, xmlsec_binary, acs): """ Loads metadata into an internal structure """ - metad = metadata.MetaData(xmlsec_binary) + metad = metadata.MetaData(xmlsec_binary, acs) if "local" in metadata_conf: for mdfile in metadata_conf["local"]: metad.import_metadata(open(mdfile).read(), mdfile) @@ -86,26 +85,27 @@ class Config(dict): else: config["key_file"] = None - if "metadata" in config: - config["metadata"] = self.load_metadata(config["metadata"], - config["xmlsec_binary"]) - if "attribute_map_dir" in config: config["attrconverters"] = ac_factory( config["attribute_map_dir"]) else: config["attrconverters"] = [AttributeConverter()] - + + if "metadata" in config: + config["metadata"] = self.load_metadata(config["metadata"], + config["xmlsec_binary"], + config["attrconverters"]) + if "sp" in config["service"]: #print config["service"]["sp"] if "metadata" in config: - self.sp_check(config["service"]["sp"], config["metadata"]) + self._sp_check(config["service"]["sp"], config["metadata"]) else: - self.sp_check(config["service"]["sp"]) + self._sp_check(config["service"]["sp"]) if "idp" in config["service"]: - self.idp_aa_check(config["service"]["idp"]) + self._idp_aa_check(config["service"]["idp"]) if "aa" in config["service"]: - self.idp_aa_check(config["service"]["aa"]) + self._idp_aa_check(config["service"]["aa"]) for key, val in config.items(): self[key] = val diff --git a/src/saml2/metadata.py b/src/saml2/metadata.py index 446bdce..2fb2cc3 100644 --- a/src/saml2/metadata.py +++ b/src/saml2/metadata.py @@ -20,7 +20,7 @@ Contains classes and functions to alleviate the handling of SAML metadata """ import httplib2 -import sys +import sys from decorator import decorator from saml2 import md, BINDING_HTTP_POST @@ -28,32 +28,35 @@ from saml2 import samlp, BINDING_HTTP_REDIRECT, BINDING_SOAP #from saml2.time_util import str_to_time from saml2.sigver import make_temp, cert_from_key_info, verify_signature from saml2.time_util import valid +from saml2.attribute_converter import ava_fro @decorator -def keep_updated(f, self, entity_id, *args, **kwargs): +def keep_updated(func, self, entity_id, *args, **kwargs): #print "In keep_updated" try: if not valid(self.entity[entity_id]["valid_until"]): self.reload_entity(entity_id) except KeyError: pass - - return f(self, entity_id, *args, **kwargs) + + return func(self, entity_id, *args, **kwargs) class MetaData(object): """ A class to manage metadata information """ - def __init__(self, xmlsec_binary=None, log=None): + def __init__(self, xmlsec_binary=None, attrconv=None, log=None): + self.log = log + self.xmlsec_binary = xmlsec_binary + self.attrconv = attrconv or [] self._loc_key = {} self._loc_bind = {} self.entity = {} self.valid_to = None self.cache_until = None - self.log = log - self.xmlsec_binary = xmlsec_binary self.http = httplib2.Http() self._import = {} - + self._wants = {} + def _vo_metadata(self, entity_descriptor, entity, tag): """ Pick out the Affiliation descriptors from an entity @@ -68,10 +71,10 @@ class MetaData(object): return members = [] - for tafd in afd: # should really never be more than one + for tafd in afd: # should really never be more than one members.extend( [member.text.strip() for member in tafd.affiliate_member]) - + if members != []: entity[tag] = members @@ -89,6 +92,9 @@ class MetaData(object): return ssds = [] + required = [] + optional = [] + #print "..... %s ..... " % entity_descriptor.entity_id for tssd in ssd: # Only want to talk to SAML 2.0 entities if samlp.NAMESPACE not in \ @@ -102,9 +108,26 @@ class MetaData(object): certs.extend(cert_from_key_info(key_desc.key_info)) certs = [make_temp(c, suffix=".der") for c in certs] + + for acs in tssd.attribute_consuming_service: + for attr in acs.requested_attribute: + print "==", attr + if attr.is_required == "true": + required.append(attr) + else: + optional.append(attr) + for acs in tssd.assertion_consumer_service: self._loc_key[acs.location] = certs - + + if required or optional: + #print "REQ",required + #print "OPT",optional + self._wants[entity_descriptor.entity_id] = (ava_fro(self.attrconv, + required), + ava_fro(self.attrconv, + optional)) + if ssds: entity[tag] = ssds @@ -136,7 +159,7 @@ class MetaData(object): certs = [make_temp(c, suffix=".der") for c in certs] for sso in tidp.single_sign_on_service: self._loc_key[sso.location] = certs - + if idps: entity[tag] = idps @@ -153,7 +176,7 @@ class MetaData(object): except AttributeError: #print "No Attribute AD: %s" % entity_descriptor.entity_id return - + aads = [] for taad in attr_auth_descr: # Remove everyone that doesn't talk SAML 2.0 @@ -168,10 +191,10 @@ class MetaData(object): #print "binding", attr_serv.binding if attr_serv.binding == BINDING_SOAP: aserv.append(attr_serv) - + if aserv == []: continue - + taad.attribute_service = aserv # gather all the certs and place them in temporary files @@ -185,12 +208,12 @@ class MetaData(object): self._loc_key[sso.location].append(certs) except KeyError: self._loc_key[sso.location] = certs - - aads.append(taad) - - if aads != []: - entity[tag] = aads + aads.append(taad) + + if aads != []: + entity[tag] = aads + def clear_from_source(self, source): for eid in self._import[source]: del self.entity[eid] @@ -202,28 +225,28 @@ class MetaData(object): return self.clear_from_source(source) if isinstance(source, basestring): - f = open(source) - self.import_metadata( f.read(), source) - f.close() - else: - self.import_external_metadata(source[0],source[1]) + fil = open(source) + self.import_metadata( fil.read(), source) + fil.close() + else: + self.import_external_metadata(source[0], source[1]) def import_metadata(self, xml_str, source): """ Import information; organization distinguish name, location and certificates from a metadata file. - + :param xml_str: The metadata as a XML string. """ - + # now = time.gmtime() entities_descriptor = md.entities_descriptor_from_string(xml_str) - + try: valid(entities_descriptor.valid_until) except AttributeError: pass - + for entity_descriptor in entities_descriptor.entity_descriptor: try: if not valid(entity_descriptor.valid_until): @@ -238,7 +261,7 @@ class MetaData(object): continue except AttributeError: pass - + try: self._import[source].append(entity_descriptor.entity_id) except KeyError: @@ -248,7 +271,7 @@ class MetaData(object): entity["valid_until"] = entities_descriptor.valid_until self._idp_metadata(entity_descriptor, entity, "idp_sso") self._sp_metadata(entity_descriptor, entity, "sp_sso") - self._aad_metadata(entity_descriptor, entity, + self._aad_metadata(entity_descriptor, entity, "attribute_authority") self._vo_metadata(entity_descriptor, entity, "affiliation") try: @@ -259,7 +282,7 @@ class MetaData(object): entity["contact"] = entity_descriptor.contact except AttributeError: pass - + def import_external_metadata(self, url, cert=None): """ Imports metadata by the use of HTTP GET. If the fingerprint is known the file will be checked for @@ -274,28 +297,28 @@ class MetaData(object): if verify_signature(content, self.xmlsec_binary, cert, "pem", "%s:%s" % (md.EntitiesDescriptor.c_namespace, md.EntitiesDescriptor.c_tag)): - self.import_metadata(content, (url,cert)) + self.import_metadata(content, (url, cert)) return True else: self.log and self.log.info("Response status: %s" % response.status) return False - + @keep_updated - def single_sign_on_services(self, entity_id, + def single_sign_on_services(self, entity_id, binding = BINDING_HTTP_REDIRECT): """ Get me all single-sign-on services that supports the specified binding version. :param entity_id: The EntityId :param binding: A binding identifier - :return: list of single-sign-on service location run by the entity + :return: list of single-sign-on service location run by the entity with the specified EntityId. """ - + # May raise KeyError idps = self.entity[entity_id]["idp_sso"] - + loc = [] #print idps for idp in idps: @@ -305,21 +328,21 @@ class MetaData(object): if binding == sso.binding: loc.append(sso.location) return loc - + @keep_updated def attribute_services(self, entity_id): try: return self.entity[entity_id]["attribute_authority"] except KeyError: return [] - + def locations(self): """ Returns all the locations that are know using this metadata file. :return: A list of IdP locations """ return self._loc_key.keys() - + def certs(self, loc): """ Get all certificates that are used by a IdP at the specified location. There can be more than one because of overlapping lifetimes @@ -340,14 +363,14 @@ class MetaData(object): return self.entity[entity_id]["affiliation"] except KeyError: return [] - + @keep_updated def consumer_url(self, entity_id, binding=BINDING_HTTP_POST, _log=None): try: ssos = self.entity[entity_id]["sp_sso"] except KeyError: raise - + # any default ? for sso in ssos: for acs in sso.assertion_consumer_service: @@ -360,7 +383,7 @@ class MetaData(object): return acs.location return None - + @keep_updated def name(self, entity_id): """ Find a name from the metadata about this entity id. @@ -368,7 +391,7 @@ class MetaData(object): ,in that order, for the organization. :param entityid: The Entity ID - :return: A name + :return: A name """ try: org = self.entity[entity_id]["organization"] @@ -386,16 +409,23 @@ class MetaData(object): name = names[0].text except KeyError: name = "" - + return name - + + @keep_updated + def wants(self, entity_id): + try: + return self._wants[entity_id] + except KeyError: + return ([], []) + @keep_updated def attribute_consumer(self, entity_id): try: ssos = self.entity[entity_id]["sp_sso"] except KeyError: return ([], []) - + required = [] optional = [] # What if there is more than one ? Can't be ? @@ -405,41 +435,41 @@ class MetaData(object): required.append(attr) else: optional.append(attr) - - return (required, optional) + return (required, optional) + def _orgname(self, org, lang="en"): if not org: return "" - for ll in [lang,None]: + for spec in [lang, None]: for name in org.organization_display_name: - if name.lang == ll: + if name.lang == spec: return name.text.strip() for name in org.organization_name: - if name.lang == ll: + if name.lang == spec: return name.text.strip() for name in org.organization_url: - if name.lang == ll: + if name.lang == spec: return name.text.strip() return "" - + def _location(self, idpsso): loc = [] for idp in idpsso: for sso in idp.single_sign_on_service: loc.append(sso.location) - + return loc - - @keep_updated - def _valid(self, entity_id): - return True - + + # @keep_updated + # def _valid(self, entity_id): + # return True + def idps(self): idps = {} for entity_id, edict in self.entity.items(): if "idp_sso" in edict: - self._valid(entity_id) +#idp_aa_check self._valid(entity_id) if "organization" in edict: name = self._orgname(edict["organization"],"en") if not name: diff --git a/src/saml2/server.py b/src/saml2/server.py index fc4c640..1b28e7b 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -22,19 +22,18 @@ or attribute authority (AA) may use to conclude its tasks. import shelve import sys -from saml2 import saml, samlp, VERSION, make_instance +from saml2 import saml, samlp, VERSION, class_name from saml2.utils import sid, decode_base64_and_inflate from saml2.utils import response_factory from saml2.utils import MissingValue, args2dict -from saml2.utils import success_status_factory, assertion_factory -from saml2.utils import OtherError, do_attribute_statement +from saml2.utils import success_status_factory +from saml2.utils import OtherError from saml2.utils import VersionMismatch, UnknownPrincipal, UnsupportedBinding from saml2.utils import status_from_exception_factory from saml2.sigver import security_context, signed_instance_factory from saml2.sigver import pre_signature_part -from saml2.time_util import instant, in_a_while from saml2.config import Config from saml2.cache import Cache from saml2.assertion import Assertion, Policy @@ -45,7 +44,7 @@ class UnknownVO(Exception): class Identifier(object): """ A class that handles identifiers of objects """ def __init__(self, dbname, entityid, voconf=None, debug=0, log=None): - self.map = shelve.open(dbname,writeback=True) + self.map = shelve.open(dbname, writeback=True) self.entityid = entityid self.voconf = voconf self.debug = debug @@ -102,11 +101,11 @@ class Identifier(object): raise UnknownVO("%s" % sp_name_qualifier) try: - format = vo_conf["nameid_format"] + nameid_format = vo_conf["nameid_format"] except KeyError: - format = saml.NAMEID_FORMAT_PERSISTENT + nameid_format = saml.NAMEID_FORMAT_PERSISTENT - return args2dict(subj_id, format=format, + return args2dict(subj_id, format=nameid_format, sp_name_qualifier=sp_name_qualifier) def persistent_nameid(self, sp_name_qualifier, userid): @@ -156,13 +155,14 @@ class Server(object): self.log = log self.debug = debug + self.ident = None if config_file: self.load_config(config_file) elif config: self.conf = config self.metadata = self.conf["metadata"] - self.sc = security_context(self.conf, log) + self.sec = security_context(self.conf, log) if cache: self.cache = Cache(cache) else: @@ -173,11 +173,11 @@ class Server(object): self.conf = Config() self.conf.load_file(config_file) if "subject_data" in self.conf: - self.id = Identifier(self.conf["subject_data"], + self.ident = Identifier(self.conf["subject_data"], self.conf["entityid"], self.conf.vo_conf, self.debug, self.log) else: - self.id = None + self.ident = None def issuer(self): """ Return an Issuer precursor """ @@ -199,7 +199,7 @@ class Server(object): response = {} request_xml = decode_base64_and_inflate(enc_request) try: - request = self.sc.correctly_signed_authn_request(request_xml) + request = self.sec.correctly_signed_authn_request(request_xml) if self.log and self.debug: self.log.info("Request was correctly signed") except Exception: @@ -324,7 +324,7 @@ class Server(object): if sign: assertion["signature"] = pre_signature_part(assertion["id"], - self.sc.my_cert, 1) + self.sec.my_cert, 1) # Store which assertion that has been sent to which SP about which # subject. @@ -335,7 +335,7 @@ class Server(object): response.update({"assertion":assertion}) - return signed_instance_factory(samlp.Response, response, self.sc) + return signed_instance_factory(samlp.Response, response, self.sec) # ------------------------------------------------------------------------ @@ -364,13 +364,12 @@ class Server(object): # ------------------------------------------------------------------------ def do_aa_response(self, consumer_url, in_response_to, sp_entity_id, - identity=None, userid="", name_id=None, ip_address="", - issuer=None, status=None, sign=False, - name_id_policy=None): + identity=None, userid="", name_id=None, status=None, + sign=False, name_id_policy=None): - name_id = self.id.construct_nameid(self.conf.aa_policy(), userid, + name_id = self.ident.construct_nameid(self.conf.aa_policy(), userid, sp_entity_id, identity) - + return self._response(consumer_url, in_response_to, sp_entity_id, identity, name_id, status, sign, policy=self.conf.aa_policy()) @@ -395,7 +394,7 @@ class Server(object): """ try: - name_id = self.id.construct_nameid(self.conf.idp_policy(), + name_id = self.ident.construct_nameid(self.conf.idp_policy(), userid, sp_entity_id, identity, name_id_policy) except IOError, exc: @@ -419,7 +418,7 @@ class Server(object): if sign: try: - return self.sc.sign_statement_using_xmlsec(response, + return self.sec.sign_statement_using_xmlsec(response, class_name(response)) except Exception, exc: response = self.error_response(destination, in_response_to, diff --git a/src/saml2/time_util.py b/src/saml2/time_util.py index b4d9aa5..6a0c4bd 100644 --- a/src/saml2/time_util.py +++ b/src/saml2/time_util.py @@ -230,3 +230,10 @@ def valid( valid_until ): return True else: return False + +def later_than(then, that): + then = str_to_time( then ) + then = str_to_time( that ) + + return then >= that + \ No newline at end of file diff --git a/src/saml2/utils.py b/src/saml2/utils.py index 9f9ceed..2f0c3f6 100644 --- a/src/saml2/utils.py +++ b/src/saml2/utils.py @@ -2,8 +2,7 @@ import time import base64 -import re -from saml2 import samlp, saml, VERSION, sigver, NAME_FORMAT_URI +from saml2 import samlp, VERSION, sigver from saml2.time_util import instant try: @@ -191,7 +190,7 @@ def response_factory(signature=False, encrypt=False, **kwargs): return args2dict(**kwargs) def _attrval(val): - if isinstance(val, list) or isinstance(val,set): + if isinstance(val, list) or isinstance(val, set): attrval = [args2dict(v) for v in val] elif val == None: attrval = None