Better identity filtering

This commit is contained in:
Roland Hedberg
2009-11-27 13:24:34 +01:00
parent 7db76f54f9
commit 64daafdc10
4 changed files with 154 additions and 77 deletions

View File

@@ -16,19 +16,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains base classes representing Saml elements.
"""Contains base classes representing SAML elements.
These codes were originally written by Jeffrey Scudder for
representing Atom elements. Takashi Matsuo had added some codes, and
changed some. Roland Hedberg changed and added some more.
Module objective: provide data classes for Saml constructs. These
classes hide the XML-ness of Saml and provide a set of native Python
Module objective: provide data classes for SAML constructs. These
classes hide the XML-ness of SAML and provide a set of native Python
classes to interact with.
Conversions to and from XML should only be necessary when the Saml classes
Conversions to and from XML should only be necessary when the SAML classes
"touch the wire" and are sent over HTTP. For this reason this module
provides methods and functions to convert Saml classes to and from strings.
provides methods and functions to convert SAML classes to and from strings.
"""
try:
@@ -123,7 +123,7 @@ class Error(Exception):
pass
class ExtensionElement(object):
"""XML which is not part of the Saml specification,
"""XML which is not part of the SAML specification,
these are called extension elements. If a classes parser
encounters an unexpected XML construct, it is translated into an
ExtensionElement instance. ExtensionElement is designed to fully
@@ -324,9 +324,9 @@ class ExtensionContainer(object):
class SamlBase(ExtensionContainer):
"""A foundation class on which Saml classes are built. It
"""A foundation class on which SAML classes are built. It
handles the parsing of attributes and children which are common to all
Saml classes. By default, the SamlBase class translates all XML child
SAML classes. By default, the SamlBase class translates all XML child
nodes into ExtensionElements.
"""

View File

@@ -26,11 +26,12 @@ from saml2.utils import kd_issuer, kd_conditions, kd_audience_restriction
from saml2.utils import sid, decode_base64_and_inflate, make_instance
from saml2.utils import kd_audience, kd_name_id, kd_assertion
from saml2.utils import kd_subject, kd_subject_confirmation, kd_response
from saml2.utils import kd_authn_statement
from saml2.utils import kd_authn_statement, MissingValue
from saml2.utils import kd_subject_confirmation_data, kd_success_status
from saml2.utils import filter_attribute_value_assertions
from saml2.utils import OtherError, do_attribute_statement
from saml2.utils import VersionMismatch, UnknownPrincipal, UnsupportedBinding
from saml2.utils import filter_on_attributes
from saml2.sigver import correctly_signed_authn_request
from saml2.sigver import pre_signature_part
@@ -201,7 +202,7 @@ class Server(object):
return in_a_while(**{"hours":1})
def do_sso_response(self, consumer_url, in_response_to,
sp_entity_id, identity, name_id=None ):
sp_entity_id, identity, name_id=None, status=None ):
""" Create a Response the follows the ??? profile.
:param consumer_url: The URL which should receive the response
@@ -240,11 +241,14 @@ class Server(object):
in_response_to=in_response_to))),
),
if not status:
status = kd_success_status()
tmp = kd_response(
issuer=self.issuer(),
in_response_to=in_response_to,
destination=consumer_url,
status=kd_success_status(),
in_response_to = in_response_to,
destination = consumer_url,
status = status,
assertion=assertion,
)
@@ -306,7 +310,8 @@ class Server(object):
return make_instance(samlp.Response, tmp)
def filter_ava(self, ava, sp_entity_id, required, optional, role=""):
def filter_ava(self, ava, sp_entity_id, required=None, optional=None,
role=""):
""" What attribute and attribute values returns depends on what
the SP has said it wants in the request or in the metadata file and
what the IdP/AA wants to release. An assumption is that what the SP
@@ -331,8 +336,8 @@ class Server(object):
if restrictions:
ava = filter_attribute_value_assertions(ava, restrictions)
if required:
pass
if required or optional:
ava = filter_on_attributes(ava, required, optional)
return ava
@@ -368,15 +373,25 @@ class Server(object):
sp_name_qualifier=name_id_policy.sp_name_qualifier)
# Do attribute filtering
(required,optional) = self.conf["metadata"].attribute_consumer(spid)
identity = self.filter_ava( identity, spid, required, optional, "idp")
resp = self.do_sso_response(
(required, optional) = self.conf["metadata"].attribute_consumer(spid)
try:
identity = self.filter_ava( identity, spid, required,
optional, "idp")
resp = self.do_sso_response(
destination, # consumer_url
in_response_to, # in_response_to
spid, # sp_entity_id
identity, # identity as dictionary
name_id,
)
except MissingValue:
resp = self.do_sso_response(
destination, # consumer_url
in_response_to, # in_response_to
spid, # sp_entity_id
name_id,
)
return ("%s" % resp).split("\n")

View File

@@ -53,6 +53,9 @@ def deflate_and_base64_encode( string_val ):
def sid(seed=""):
"""The hash of the server time + seed makes an unique SID for each session.
:param seed: A seed string
:return: The hex version of the digest
"""
ident = md5()
ident.update(repr(time.time()))
@@ -195,67 +198,83 @@ def identity_attribute(form, attribute, forward_map=None):
# default is name
return attribute.name
def filter_values(vals, attributes, required=True):
reqval = []
for rval in attributes:
for val in vals:
if rval.text == val:
reqval.append(val)
break
def filter_values(vals, required=None, optional=None):
""" Removes values from *val* that does not appear in *attributes*.
:param val: The values that are to be filtered
:param required: The requires values
:param optional: The optional values
:return: The set of values after filtering
"""
if not required and not optional:
return vals
valr = []
valo = []
if required:
if len(reqval) == len(attributes):
return reqval
rvals = [v.text for v in required]
else:
rvals = []
if optional:
ovals = [v.text for v in optional]
else:
ovals = []
for val in vals:
if val in rvals:
valr.append(val)
elif val in ovals:
valo.append(val)
valo.extend(valr)
if rvals:
if len(rvals) == len(valr):
return valo
else:
raise MissingValue("Required attribute value missing")
else:
return reqval
return valo
def filter_required(ava, required):
"""
def combine(required=None, optional=None):
res = {}
if not required:
required = []
if not optional:
optional = []
for attr in required:
part = None
for oat in optional:
if attr.name == oat.name:
part = (attr.attribute_value, oat.attribute_value)
break
if part:
res[(attr.name, attr.friendly_name)] = part
else:
res[(attr.name, attr.friendly_name)] = (attr.attribute_value, [])
for oat in optional:
tag = (oat.name, oat.friendly_name)
if tag not in res:
res[tag] = ([], oat.attribute_value)
return res
def filter_on_attributes(ava, required=None, optional=None):
""" Filter
:param required: list of RequestedAttribute instances
"""
res = {}
for attr in required:
if attr.name in ava:
if required.attribute_value:
res[attr.name] = filter_values(ava[attr.name],
required.attribute_value)
else:
res[attr.name] = ava[attr.name]
elif attr.friendly_name in ava:
if attr.attribute_value:
res[attr.friendly_name] = filter_values(
ava[attr.friendly_name],
attr.attribute_value)
else:
res[attr.friendly_name] = ava[attr.friendly_name]
comb = combine(required, optional)
for attr, vals in comb.items():
if attr[0] in ava:
res[attr[0]] = filter_values(ava[attr[0]], vals[0], vals[1])
elif attr[1] in ava:
res[attr[1]] = filter_values(ava[attr[1]], vals[0], vals[1])
else:
raise MissingValue("Required attribute missing")
return res
def filter_optional(ava, optional):
"""
:param optional: list of RequestedAttribute instances
"""
res = {}
for attr in optional:
if attr.name in ava:
if optional.attribute_value:
res[attr.name] = filter_values(ava[attr.name],
optional.attribute_value, False)
else:
res[attr.name] = ava[attr.name]
elif attr.friendly_name in ava:
if attr.attribute_value:
res[attr.friendly_name] = filter_values(
ava[attr.friendly_name],
attr.attribute_value, False)
else:
res[attr.friendly_name] = ava[attr.friendly_name]
return res
#----------------------------------------------------------------------------

View File

@@ -377,25 +377,39 @@ def test_identity_attribute_4():
# if there would be a map it would be serialNumber
assert utils.identity_attribute("friendly",a) == "serialNumber"
def test_filter_values_req_0():
def test_combine_0():
r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
friendly_name="serialNumber")
o = Attribute(name="urn:oid:2.5.4.4", name_format=NAME_FORMAT_URI,
friendly_name="surName")
comb = utils.combine([r],[o])
print comb
assert _eq(comb.keys(), [('urn:oid:2.5.4.5', 'serialNumber'),
('urn:oid:2.5.4.4', 'surName')])
assert comb[('urn:oid:2.5.4.5', 'serialNumber')] == ([], [])
assert comb[('urn:oid:2.5.4.4', 'surName')] == ([], [])
def test_filter_on_attributes_0():
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
friendly_name="serialNumber")
required = [a]
ava = { "serialNumber": ["12345"]}
ava = utils.filter_required(ava, required)
ava = utils.filter_on_attributes(ava, required)
assert ava.keys() == ["serialNumber"]
assert ava["serialNumber"] == ["12345"]
def test_filter_values_req_1():
def test_filter_on_attributes_1():
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
friendly_name="serialNumber")
required = [a]
ava = { "serialNumber": ["12345"], "givenName":["Lars"]}
ava = utils.filter_required(ava, required)
ava = utils.filter_on_attributes(ava, required)
assert ava.keys() == ["serialNumber"]
assert ava["serialNumber"] == ["12345"]
@@ -408,7 +422,7 @@ def test_filter_values_req_2():
required = [a1,a2]
ava = { "serialNumber": ["12345"], "givenName":["Lars"]}
raises(utils.MissingValue, utils.filter_required, ava, required)
raises(utils.MissingValue, utils.filter_on_attributes, ava, required)
def test_filter_values_req_3():
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
@@ -418,7 +432,7 @@ def test_filter_values_req_3():
required = [a]
ava = { "serialNumber": ["12345"]}
ava = utils.filter_required(ava, required)
ava = utils.filter_on_attributes(ava, required)
assert ava.keys() == ["serialNumber"]
assert ava["serialNumber"] == ["12345"]
@@ -430,7 +444,7 @@ def test_filter_values_req_4():
required = [a]
ava = { "serialNumber": ["12345"]}
raises(utils.MissingValue, utils.filter_required, ava, required)
raises(utils.MissingValue, utils.filter_on_attributes, ava, required)
def test_filter_values_req_5():
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
@@ -440,7 +454,7 @@ def test_filter_values_req_5():
required = [a]
ava = { "serialNumber": ["12345", "54321"]}
ava = utils.filter_required(ava, required)
ava = utils.filter_on_attributes(ava, required)
assert ava.keys() == ["serialNumber"]
assert ava["serialNumber"] == ["12345"]
@@ -452,6 +466,35 @@ def test_filter_values_req_6():
required = [a]
ava = { "serialNumber": ["12345", "54321"]}
ava = utils.filter_required(ava, required)
ava = utils.filter_on_attributes(ava, required)
assert ava.keys() == ["serialNumber"]
assert ava["serialNumber"] == ["54321"]
def test_filter_values_req_opt_0():
r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
friendly_name="serialNumber", attribute_value=[
AttributeValue(text="54321")])
o = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
friendly_name="serialNumber", attribute_value=[
AttributeValue(text="12345")])
ava = { "serialNumber": ["12345", "54321"]}
ava = utils.filter_on_attributes(ava, [r], [o])
assert ava.keys() == ["serialNumber"]
assert _eq(ava["serialNumber"], ["12345","54321"])
def test_filter_values_req_opt_1():
r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
friendly_name="serialNumber", attribute_value=[
AttributeValue(text="54321")])
o = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
friendly_name="serialNumber", attribute_value=[
AttributeValue(text="12345"),
AttributeValue(text="abcd0")])
ava = { "serialNumber": ["12345", "54321"]}
ava = utils.filter_on_attributes(ava, [r], [o])
assert ava.keys() == ["serialNumber"]
assert _eq(ava["serialNumber"], ["12345","54321"])