Better identity filtering
This commit is contained in:
@@ -16,19 +16,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains base classes representing Saml elements.
|
||||
"""Contains base classes representing SAML elements.
|
||||
|
||||
These codes were originally written by Jeffrey Scudder for
|
||||
representing Atom elements. Takashi Matsuo had added some codes, and
|
||||
changed some. Roland Hedberg changed and added some more.
|
||||
|
||||
Module objective: provide data classes for Saml constructs. These
|
||||
classes hide the XML-ness of Saml and provide a set of native Python
|
||||
Module objective: provide data classes for SAML constructs. These
|
||||
classes hide the XML-ness of SAML and provide a set of native Python
|
||||
classes to interact with.
|
||||
|
||||
Conversions to and from XML should only be necessary when the Saml classes
|
||||
Conversions to and from XML should only be necessary when the SAML classes
|
||||
"touch the wire" and are sent over HTTP. For this reason this module
|
||||
provides methods and functions to convert Saml classes to and from strings.
|
||||
provides methods and functions to convert SAML classes to and from strings.
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -123,7 +123,7 @@ class Error(Exception):
|
||||
pass
|
||||
|
||||
class ExtensionElement(object):
|
||||
"""XML which is not part of the Saml specification,
|
||||
"""XML which is not part of the SAML specification,
|
||||
these are called extension elements. If a classes parser
|
||||
encounters an unexpected XML construct, it is translated into an
|
||||
ExtensionElement instance. ExtensionElement is designed to fully
|
||||
@@ -324,9 +324,9 @@ class ExtensionContainer(object):
|
||||
|
||||
|
||||
class SamlBase(ExtensionContainer):
|
||||
"""A foundation class on which Saml classes are built. It
|
||||
"""A foundation class on which SAML classes are built. It
|
||||
handles the parsing of attributes and children which are common to all
|
||||
Saml classes. By default, the SamlBase class translates all XML child
|
||||
SAML classes. By default, the SamlBase class translates all XML child
|
||||
nodes into ExtensionElements.
|
||||
"""
|
||||
|
||||
|
||||
@@ -26,11 +26,12 @@ from saml2.utils import kd_issuer, kd_conditions, kd_audience_restriction
|
||||
from saml2.utils import sid, decode_base64_and_inflate, make_instance
|
||||
from saml2.utils import kd_audience, kd_name_id, kd_assertion
|
||||
from saml2.utils import kd_subject, kd_subject_confirmation, kd_response
|
||||
from saml2.utils import kd_authn_statement
|
||||
from saml2.utils import kd_authn_statement, MissingValue
|
||||
from saml2.utils import kd_subject_confirmation_data, kd_success_status
|
||||
from saml2.utils import filter_attribute_value_assertions
|
||||
from saml2.utils import OtherError, do_attribute_statement
|
||||
from saml2.utils import VersionMismatch, UnknownPrincipal, UnsupportedBinding
|
||||
from saml2.utils import filter_on_attributes
|
||||
|
||||
from saml2.sigver import correctly_signed_authn_request
|
||||
from saml2.sigver import pre_signature_part
|
||||
@@ -201,7 +202,7 @@ class Server(object):
|
||||
return in_a_while(**{"hours":1})
|
||||
|
||||
def do_sso_response(self, consumer_url, in_response_to,
|
||||
sp_entity_id, identity, name_id=None ):
|
||||
sp_entity_id, identity, name_id=None, status=None ):
|
||||
""" Create a Response the follows the ??? profile.
|
||||
|
||||
:param consumer_url: The URL which should receive the response
|
||||
@@ -240,11 +241,14 @@ class Server(object):
|
||||
in_response_to=in_response_to))),
|
||||
),
|
||||
|
||||
if not status:
|
||||
status = kd_success_status()
|
||||
|
||||
tmp = kd_response(
|
||||
issuer=self.issuer(),
|
||||
in_response_to=in_response_to,
|
||||
destination=consumer_url,
|
||||
status=kd_success_status(),
|
||||
in_response_to = in_response_to,
|
||||
destination = consumer_url,
|
||||
status = status,
|
||||
assertion=assertion,
|
||||
)
|
||||
|
||||
@@ -306,7 +310,8 @@ class Server(object):
|
||||
|
||||
return make_instance(samlp.Response, tmp)
|
||||
|
||||
def filter_ava(self, ava, sp_entity_id, required, optional, role=""):
|
||||
def filter_ava(self, ava, sp_entity_id, required=None, optional=None,
|
||||
role=""):
|
||||
""" What attribute and attribute values returns depends on what
|
||||
the SP has said it wants in the request or in the metadata file and
|
||||
what the IdP/AA wants to release. An assumption is that what the SP
|
||||
@@ -331,8 +336,8 @@ class Server(object):
|
||||
if restrictions:
|
||||
ava = filter_attribute_value_assertions(ava, restrictions)
|
||||
|
||||
if required:
|
||||
pass
|
||||
if required or optional:
|
||||
ava = filter_on_attributes(ava, required, optional)
|
||||
|
||||
return ava
|
||||
|
||||
@@ -368,15 +373,25 @@ class Server(object):
|
||||
sp_name_qualifier=name_id_policy.sp_name_qualifier)
|
||||
|
||||
# Do attribute filtering
|
||||
(required,optional) = self.conf["metadata"].attribute_consumer(spid)
|
||||
identity = self.filter_ava( identity, spid, required, optional, "idp")
|
||||
|
||||
resp = self.do_sso_response(
|
||||
(required, optional) = self.conf["metadata"].attribute_consumer(spid)
|
||||
try:
|
||||
identity = self.filter_ava( identity, spid, required,
|
||||
optional, "idp")
|
||||
resp = self.do_sso_response(
|
||||
destination, # consumer_url
|
||||
in_response_to, # in_response_to
|
||||
spid, # sp_entity_id
|
||||
identity, # identity as dictionary
|
||||
name_id,
|
||||
)
|
||||
except MissingValue:
|
||||
resp = self.do_sso_response(
|
||||
destination, # consumer_url
|
||||
in_response_to, # in_response_to
|
||||
spid, # sp_entity_id
|
||||
name_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
return ("%s" % resp).split("\n")
|
||||
@@ -53,6 +53,9 @@ def deflate_and_base64_encode( string_val ):
|
||||
|
||||
def sid(seed=""):
|
||||
"""The hash of the server time + seed makes an unique SID for each session.
|
||||
|
||||
:param seed: A seed string
|
||||
:return: The hex version of the digest
|
||||
"""
|
||||
ident = md5()
|
||||
ident.update(repr(time.time()))
|
||||
@@ -195,67 +198,83 @@ def identity_attribute(form, attribute, forward_map=None):
|
||||
# default is name
|
||||
return attribute.name
|
||||
|
||||
def filter_values(vals, attributes, required=True):
|
||||
reqval = []
|
||||
for rval in attributes:
|
||||
for val in vals:
|
||||
if rval.text == val:
|
||||
reqval.append(val)
|
||||
break
|
||||
|
||||
def filter_values(vals, required=None, optional=None):
|
||||
""" Removes values from *val* that does not appear in *attributes*.
|
||||
|
||||
:param val: The values that are to be filtered
|
||||
:param required: The requires values
|
||||
:param optional: The optional values
|
||||
:return: The set of values after filtering
|
||||
"""
|
||||
|
||||
if not required and not optional:
|
||||
return vals
|
||||
|
||||
valr = []
|
||||
valo = []
|
||||
if required:
|
||||
if len(reqval) == len(attributes):
|
||||
return reqval
|
||||
rvals = [v.text for v in required]
|
||||
else:
|
||||
rvals = []
|
||||
if optional:
|
||||
ovals = [v.text for v in optional]
|
||||
else:
|
||||
ovals = []
|
||||
for val in vals:
|
||||
if val in rvals:
|
||||
valr.append(val)
|
||||
elif val in ovals:
|
||||
valo.append(val)
|
||||
|
||||
valo.extend(valr)
|
||||
if rvals:
|
||||
if len(rvals) == len(valr):
|
||||
return valo
|
||||
else:
|
||||
raise MissingValue("Required attribute value missing")
|
||||
else:
|
||||
return reqval
|
||||
return valo
|
||||
|
||||
def filter_required(ava, required):
|
||||
"""
|
||||
def combine(required=None, optional=None):
|
||||
res = {}
|
||||
if not required:
|
||||
required = []
|
||||
if not optional:
|
||||
optional = []
|
||||
for attr in required:
|
||||
part = None
|
||||
for oat in optional:
|
||||
if attr.name == oat.name:
|
||||
part = (attr.attribute_value, oat.attribute_value)
|
||||
break
|
||||
if part:
|
||||
res[(attr.name, attr.friendly_name)] = part
|
||||
else:
|
||||
res[(attr.name, attr.friendly_name)] = (attr.attribute_value, [])
|
||||
|
||||
for oat in optional:
|
||||
tag = (oat.name, oat.friendly_name)
|
||||
if tag not in res:
|
||||
res[tag] = ([], oat.attribute_value)
|
||||
|
||||
return res
|
||||
|
||||
def filter_on_attributes(ava, required=None, optional=None):
|
||||
""" Filter
|
||||
:param required: list of RequestedAttribute instances
|
||||
"""
|
||||
res = {}
|
||||
for attr in required:
|
||||
if attr.name in ava:
|
||||
if required.attribute_value:
|
||||
res[attr.name] = filter_values(ava[attr.name],
|
||||
required.attribute_value)
|
||||
else:
|
||||
res[attr.name] = ava[attr.name]
|
||||
elif attr.friendly_name in ava:
|
||||
if attr.attribute_value:
|
||||
res[attr.friendly_name] = filter_values(
|
||||
ava[attr.friendly_name],
|
||||
attr.attribute_value)
|
||||
else:
|
||||
res[attr.friendly_name] = ava[attr.friendly_name]
|
||||
comb = combine(required, optional)
|
||||
for attr, vals in comb.items():
|
||||
if attr[0] in ava:
|
||||
res[attr[0]] = filter_values(ava[attr[0]], vals[0], vals[1])
|
||||
elif attr[1] in ava:
|
||||
res[attr[1]] = filter_values(ava[attr[1]], vals[0], vals[1])
|
||||
else:
|
||||
raise MissingValue("Required attribute missing")
|
||||
|
||||
return res
|
||||
|
||||
def filter_optional(ava, optional):
|
||||
"""
|
||||
:param optional: list of RequestedAttribute instances
|
||||
"""
|
||||
res = {}
|
||||
for attr in optional:
|
||||
if attr.name in ava:
|
||||
if optional.attribute_value:
|
||||
res[attr.name] = filter_values(ava[attr.name],
|
||||
optional.attribute_value, False)
|
||||
else:
|
||||
res[attr.name] = ava[attr.name]
|
||||
elif attr.friendly_name in ava:
|
||||
if attr.attribute_value:
|
||||
res[attr.friendly_name] = filter_values(
|
||||
ava[attr.friendly_name],
|
||||
attr.attribute_value, False)
|
||||
else:
|
||||
res[attr.friendly_name] = ava[attr.friendly_name]
|
||||
|
||||
return res
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -377,25 +377,39 @@ def test_identity_attribute_4():
|
||||
# if there would be a map it would be serialNumber
|
||||
assert utils.identity_attribute("friendly",a) == "serialNumber"
|
||||
|
||||
def test_filter_values_req_0():
|
||||
def test_combine_0():
|
||||
r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="serialNumber")
|
||||
o = Attribute(name="urn:oid:2.5.4.4", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="surName")
|
||||
|
||||
comb = utils.combine([r],[o])
|
||||
print comb
|
||||
assert _eq(comb.keys(), [('urn:oid:2.5.4.5', 'serialNumber'),
|
||||
('urn:oid:2.5.4.4', 'surName')])
|
||||
assert comb[('urn:oid:2.5.4.5', 'serialNumber')] == ([], [])
|
||||
assert comb[('urn:oid:2.5.4.4', 'surName')] == ([], [])
|
||||
|
||||
|
||||
def test_filter_on_attributes_0():
|
||||
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="serialNumber")
|
||||
|
||||
required = [a]
|
||||
ava = { "serialNumber": ["12345"]}
|
||||
|
||||
ava = utils.filter_required(ava, required)
|
||||
ava = utils.filter_on_attributes(ava, required)
|
||||
assert ava.keys() == ["serialNumber"]
|
||||
assert ava["serialNumber"] == ["12345"]
|
||||
|
||||
def test_filter_values_req_1():
|
||||
def test_filter_on_attributes_1():
|
||||
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="serialNumber")
|
||||
|
||||
required = [a]
|
||||
ava = { "serialNumber": ["12345"], "givenName":["Lars"]}
|
||||
|
||||
ava = utils.filter_required(ava, required)
|
||||
ava = utils.filter_on_attributes(ava, required)
|
||||
assert ava.keys() == ["serialNumber"]
|
||||
assert ava["serialNumber"] == ["12345"]
|
||||
|
||||
@@ -408,7 +422,7 @@ def test_filter_values_req_2():
|
||||
required = [a1,a2]
|
||||
ava = { "serialNumber": ["12345"], "givenName":["Lars"]}
|
||||
|
||||
raises(utils.MissingValue, utils.filter_required, ava, required)
|
||||
raises(utils.MissingValue, utils.filter_on_attributes, ava, required)
|
||||
|
||||
def test_filter_values_req_3():
|
||||
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
@@ -418,7 +432,7 @@ def test_filter_values_req_3():
|
||||
required = [a]
|
||||
ava = { "serialNumber": ["12345"]}
|
||||
|
||||
ava = utils.filter_required(ava, required)
|
||||
ava = utils.filter_on_attributes(ava, required)
|
||||
assert ava.keys() == ["serialNumber"]
|
||||
assert ava["serialNumber"] == ["12345"]
|
||||
|
||||
@@ -430,7 +444,7 @@ def test_filter_values_req_4():
|
||||
required = [a]
|
||||
ava = { "serialNumber": ["12345"]}
|
||||
|
||||
raises(utils.MissingValue, utils.filter_required, ava, required)
|
||||
raises(utils.MissingValue, utils.filter_on_attributes, ava, required)
|
||||
|
||||
def test_filter_values_req_5():
|
||||
a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
@@ -440,7 +454,7 @@ def test_filter_values_req_5():
|
||||
required = [a]
|
||||
ava = { "serialNumber": ["12345", "54321"]}
|
||||
|
||||
ava = utils.filter_required(ava, required)
|
||||
ava = utils.filter_on_attributes(ava, required)
|
||||
assert ava.keys() == ["serialNumber"]
|
||||
assert ava["serialNumber"] == ["12345"]
|
||||
|
||||
@@ -452,6 +466,35 @@ def test_filter_values_req_6():
|
||||
required = [a]
|
||||
ava = { "serialNumber": ["12345", "54321"]}
|
||||
|
||||
ava = utils.filter_required(ava, required)
|
||||
ava = utils.filter_on_attributes(ava, required)
|
||||
assert ava.keys() == ["serialNumber"]
|
||||
assert ava["serialNumber"] == ["54321"]
|
||||
|
||||
def test_filter_values_req_opt_0():
|
||||
r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="serialNumber", attribute_value=[
|
||||
AttributeValue(text="54321")])
|
||||
o = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="serialNumber", attribute_value=[
|
||||
AttributeValue(text="12345")])
|
||||
|
||||
ava = { "serialNumber": ["12345", "54321"]}
|
||||
|
||||
ava = utils.filter_on_attributes(ava, [r], [o])
|
||||
assert ava.keys() == ["serialNumber"]
|
||||
assert _eq(ava["serialNumber"], ["12345","54321"])
|
||||
|
||||
def test_filter_values_req_opt_1():
|
||||
r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="serialNumber", attribute_value=[
|
||||
AttributeValue(text="54321")])
|
||||
o = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI,
|
||||
friendly_name="serialNumber", attribute_value=[
|
||||
AttributeValue(text="12345"),
|
||||
AttributeValue(text="abcd0")])
|
||||
|
||||
ava = { "serialNumber": ["12345", "54321"]}
|
||||
|
||||
ava = utils.filter_on_attributes(ava, [r], [o])
|
||||
assert ava.keys() == ["serialNumber"]
|
||||
assert _eq(ava["serialNumber"], ["12345","54321"])
|
||||
|
||||
Reference in New Issue
Block a user