Handle 'null' attribute value

This commit is contained in:
Roland Hedberg
2013-01-11 12:17:54 +01:00
parent 7f4d618124
commit 78e619c521
2 changed files with 89 additions and 17 deletions

View File

@@ -20,6 +20,7 @@ XSI_NAMESPACE = 'http://www.w3.org/2001/XMLSchema-instance'
XS_NAMESPACE = 'http://www.w3.org/2001/XMLSchema'
XSI_TYPE = '{%s}type' % XSI_NAMESPACE
XSI_NIL = '{%s}nil' % XSI_NAMESPACE
NAMEID_FORMAT_EMAILADDRESS = (
"urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress")
@@ -97,19 +98,58 @@ def _verify_value_type(typ, val):
import base64
return base64.decodestring(val)
TYPE_EXTENSION = '{%s}type' % XSI_NAMESPACE
class AttributeValueBase(SamlBase):
def __setattr__(self, key, value):
if key == "text":
self.set_text(value)
else:
SamlBase.__setattr__(self,key, value)
def __init__(self,
text=None,
extension_elements=None,
extension_attributes=None,
):
self._extatt = {}
SamlBase.__init__(self,
text=None,
extension_elements=extension_elements,
extension_attributes=extension_attributes,
)
if self._extatt:
self.extension_attributes = self._extatt
if not text:
self.extension_attributes = {XSI_NIL: 'true'}
else:
self.set_text(text)
def verify(self):
if not self.text:
assert self.extension_attributes
assert self.extension_attributes[XSI_NIL] == "true"
return True
else:
SamlBase.verify(self)
def set_type(self, typ):
self.extension_attributes[TYPE_EXTENSION] = typ
try:
self.extension_attributes[XSI_TYPE] = typ
except AttributeError:
self._extatt[XSI_TYPE] = typ
def get_type(self):
try:
return self.extension_attributes[TYPE_EXTENSION]
except KeyError:
return ""
return self.extension_attributes[XSI_TYPE]
except (KeyError, AttributeError):
try:
return self._extatt[XSI_TYPE]
except KeyError:
return ""
def set_text(self, val, base64encode=False):
typ = self.get_type()
if base64encode:
@@ -120,6 +160,8 @@ class AttributeValueBase(SamlBase):
if isinstance(val, basestring):
if not typ:
self.set_type("xs:string")
else:
assert typ == "xs:string"
elif isinstance(val, bool):
if val:
val = "true"
@@ -127,20 +169,33 @@ class AttributeValueBase(SamlBase):
val = "false"
if not typ:
self.set_type("xs:boolean")
else:
assert typ == "xs:boolean"
elif isinstance(val, int):
val = str(val)
if not typ:
self.set_type("xs:integer")
else:
assert typ == "xs:integer"
elif isinstance(val, float):
val = str(val)
if not typ:
self.set_type("xs:float")
elif val is None:
else:
assert typ == "xs:float"
elif not val:
try:
self.extension_attributes[XSI_TYPE] = typ
except AttributeError:
self._extatt[XSI_TYPE] = typ
val = ""
else:
raise ValueError
setattr(self, "text", val)
if typ == "xs:anyType":
pass
else:
raise ValueError
SamlBase.__setattr__(self, "text", val)
return self
def harvest_element_tree(self, tree):
@@ -153,7 +208,7 @@ class AttributeValueBase(SamlBase):
#print "set_text:", tree.text
self.set_text(tree.text)
try:
typ = self.extension_attributes[TYPE_EXTENSION]
typ = self.extension_attributes[XSI_TYPE]
_verify_value_type(typ, getattr(self, "text"))
except KeyError:
pass

View File

@@ -1,11 +1,14 @@
import calendar
import sys
import urlparse
import re
import time_util
import struct
import base64
# Also defined in saml2.saml but can't import from there
XSI_NAMESPACE = 'http://www.w3.org/2001/XMLSchema-instance'
XSI_NIL = '{%s}nil' % XSI_NAMESPACE
# ---------------------------------------------------------
class NotValid(Exception):
pass
@@ -48,24 +51,27 @@ def valid_any_uri(item):
# raise NotValid("AnyURI")
return True
def valid_date_time(item):
try:
time_util.str_to_time(item)
except Exception:
raise NotValid("dateTime")
return True
def valid_url(url):
try:
part = urlparse.urlparse(url)
_ = urlparse.urlparse(url)
except Exception:
raise NotValid("URL")
# if part[1] == "localhost" or part[1] == "127.0.0.1":
# raise NotValid("URL")
return True
def validate_on_or_after(not_on_or_after, slack):
if not_on_or_after:
now = time_util.utc_now()
@@ -309,6 +315,11 @@ def valid_instance(instance):
instclass = instance.__class__
class_name = instclass.__name__
if instance.text:
_has_val = True
else:
_has_val = False
if instclass.c_value_type and instance.text:
try:
validate_value_type(instance.text.strip(),
@@ -356,6 +367,7 @@ def valid_instance(instance):
_cmin = _cmax = _card = None
if value:
_has_val = True
if isinstance(value, list):
_list = True
vlen = len(value)
@@ -387,6 +399,11 @@ def valid_instance(instance):
"Class '%s' instance cardinality error: %s" % \
(class_name, "too few values on %s" % name))
if not _has_val:
# Not allow unless xsi:nil="true"
assert instance.extension_attributes
assert instance.extension_attributes[XSI_NIL] == "true"
return True
def valid_domain_name(dns_name):