diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index c59cf33..ab3d839 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -228,6 +228,35 @@ class ExtensionElement(object): return results + def loadd(self, ava): + """ expects a special set of keys """ + + if "attributes" in ava: + for key, val in ava["attributes"].items(): + self.attributes[key] = val + + try: + self.tag = ava["tag"] + except KeyError: + if not self.tag: + raise KeyError("ExtensionElement must have a tag") + + try: + self.namespace = ava["namespace"] + except KeyError: + if not self.namespace: + raise KeyError("ExtensionElement must belong to a namespace") + + try: + self.text = ava["text"] + except KeyError: + pass + + if "children" in ava: + for item in ava["children"]: + self.children.append(ExtensionElement(item["tag"]).loadd(item)) + + return self def extension_element_from_string(xml_string): element_tree = ElementTree.fromstring(xml_string) @@ -526,7 +555,7 @@ class SamlBase(ExtensionContainer): elif val == None: pass else: - raise ValueError( "Type it shouldn't be '%s'" % (val,)) + raise ValueError( "Type shouldn't be '%s'" % (val,)) return self @@ -557,6 +586,14 @@ class SamlBase(ExtensionContainer): base64encode) setattr(self, prop, cis) + if "extension_elements" in ava: + for item in ava["extension_elements"]: + self.extension_elements.append(ExtensionElement(item["tag"]).loadd(item)) + + if "extension_attributes" in ava: + for key, val in ava["extension_attributes"].items(): + self.extension_attributes[key] = val + return self