From 64daafdc10e962a761cf57037115a3471b890b6f Mon Sep 17 00:00:00 2001 From: Roland Hedberg Date: Fri, 27 Nov 2009 13:24:34 +0100 Subject: [PATCH] Better identity filtering --- src/saml2/__init__.py | 16 +++--- src/saml2/server.py | 39 +++++++++----- src/saml2/utils.py | 115 ++++++++++++++++++++++++------------------ tests/test_utils.py | 61 ++++++++++++++++++---- 4 files changed, 154 insertions(+), 77 deletions(-) diff --git a/src/saml2/__init__.py b/src/saml2/__init__.py index e16b1b5..644213d 100644 --- a/src/saml2/__init__.py +++ b/src/saml2/__init__.py @@ -16,19 +16,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Contains base classes representing Saml elements. +"""Contains base classes representing SAML elements. These codes were originally written by Jeffrey Scudder for representing Atom elements. Takashi Matsuo had added some codes, and changed some. Roland Hedberg changed and added some more. - Module objective: provide data classes for Saml constructs. These - classes hide the XML-ness of Saml and provide a set of native Python + Module objective: provide data classes for SAML constructs. These + classes hide the XML-ness of SAML and provide a set of native Python classes to interact with. - Conversions to and from XML should only be necessary when the Saml classes + Conversions to and from XML should only be necessary when the SAML classes "touch the wire" and are sent over HTTP. For this reason this module - provides methods and functions to convert Saml classes to and from strings. + provides methods and functions to convert SAML classes to and from strings. """ try: @@ -123,7 +123,7 @@ class Error(Exception): pass class ExtensionElement(object): - """XML which is not part of the Saml specification, + """XML which is not part of the SAML specification, these are called extension elements. If a classes parser encounters an unexpected XML construct, it is translated into an ExtensionElement instance. ExtensionElement is designed to fully @@ -324,9 +324,9 @@ class ExtensionContainer(object): class SamlBase(ExtensionContainer): - """A foundation class on which Saml classes are built. It + """A foundation class on which SAML classes are built. It handles the parsing of attributes and children which are common to all - Saml classes. By default, the SamlBase class translates all XML child + SAML classes. By default, the SamlBase class translates all XML child nodes into ExtensionElements. """ diff --git a/src/saml2/server.py b/src/saml2/server.py index 6f106ca..ce8d7f3 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -26,11 +26,12 @@ from saml2.utils import kd_issuer, kd_conditions, kd_audience_restriction from saml2.utils import sid, decode_base64_and_inflate, make_instance from saml2.utils import kd_audience, kd_name_id, kd_assertion from saml2.utils import kd_subject, kd_subject_confirmation, kd_response -from saml2.utils import kd_authn_statement +from saml2.utils import kd_authn_statement, MissingValue from saml2.utils import kd_subject_confirmation_data, kd_success_status from saml2.utils import filter_attribute_value_assertions from saml2.utils import OtherError, do_attribute_statement from saml2.utils import VersionMismatch, UnknownPrincipal, UnsupportedBinding +from saml2.utils import filter_on_attributes from saml2.sigver import correctly_signed_authn_request from saml2.sigver import pre_signature_part @@ -201,7 +202,7 @@ class Server(object): return in_a_while(**{"hours":1}) def do_sso_response(self, consumer_url, in_response_to, - sp_entity_id, identity, name_id=None ): + sp_entity_id, identity, name_id=None, status=None ): """ Create a Response the follows the ??? profile. :param consumer_url: The URL which should receive the response @@ -240,11 +241,14 @@ class Server(object): in_response_to=in_response_to))), ), + if not status: + status = kd_success_status() + tmp = kd_response( issuer=self.issuer(), - in_response_to=in_response_to, - destination=consumer_url, - status=kd_success_status(), + in_response_to = in_response_to, + destination = consumer_url, + status = status, assertion=assertion, ) @@ -306,7 +310,8 @@ class Server(object): return make_instance(samlp.Response, tmp) - def filter_ava(self, ava, sp_entity_id, required, optional, role=""): + def filter_ava(self, ava, sp_entity_id, required=None, optional=None, + role=""): """ What attribute and attribute values returns depends on what the SP has said it wants in the request or in the metadata file and what the IdP/AA wants to release. An assumption is that what the SP @@ -331,8 +336,8 @@ class Server(object): if restrictions: ava = filter_attribute_value_assertions(ava, restrictions) - if required: - pass + if required or optional: + ava = filter_on_attributes(ava, required, optional) return ava @@ -368,15 +373,25 @@ class Server(object): sp_name_qualifier=name_id_policy.sp_name_qualifier) # Do attribute filtering - (required,optional) = self.conf["metadata"].attribute_consumer(spid) - identity = self.filter_ava( identity, spid, required, optional, "idp") - - resp = self.do_sso_response( + (required, optional) = self.conf["metadata"].attribute_consumer(spid) + try: + identity = self.filter_ava( identity, spid, required, + optional, "idp") + resp = self.do_sso_response( destination, # consumer_url in_response_to, # in_response_to spid, # sp_entity_id identity, # identity as dictionary name_id, ) + except MissingValue: + resp = self.do_sso_response( + destination, # consumer_url + in_response_to, # in_response_to + spid, # sp_entity_id + name_id, + ) + + return ("%s" % resp).split("\n") \ No newline at end of file diff --git a/src/saml2/utils.py b/src/saml2/utils.py index fc24a37..2b841b8 100644 --- a/src/saml2/utils.py +++ b/src/saml2/utils.py @@ -53,6 +53,9 @@ def deflate_and_base64_encode( string_val ): def sid(seed=""): """The hash of the server time + seed makes an unique SID for each session. + + :param seed: A seed string + :return: The hex version of the digest """ ident = md5() ident.update(repr(time.time())) @@ -195,67 +198,83 @@ def identity_attribute(form, attribute, forward_map=None): # default is name return attribute.name -def filter_values(vals, attributes, required=True): - reqval = [] - for rval in attributes: - for val in vals: - if rval.text == val: - reqval.append(val) - break - +def filter_values(vals, required=None, optional=None): + """ Removes values from *val* that does not appear in *attributes*. + + :param val: The values that are to be filtered + :param required: The requires values + :param optional: The optional values + :return: The set of values after filtering + """ + + if not required and not optional: + return vals + + valr = [] + valo = [] if required: - if len(reqval) == len(attributes): - return reqval + rvals = [v.text for v in required] + else: + rvals = [] + if optional: + ovals = [v.text for v in optional] + else: + ovals = [] + for val in vals: + if val in rvals: + valr.append(val) + elif val in ovals: + valo.append(val) + + valo.extend(valr) + if rvals: + if len(rvals) == len(valr): + return valo else: raise MissingValue("Required attribute value missing") else: - return reqval + return valo -def filter_required(ava, required): - """ +def combine(required=None, optional=None): + res = {} + if not required: + required = [] + if not optional: + optional = [] + for attr in required: + part = None + for oat in optional: + if attr.name == oat.name: + part = (attr.attribute_value, oat.attribute_value) + break + if part: + res[(attr.name, attr.friendly_name)] = part + else: + res[(attr.name, attr.friendly_name)] = (attr.attribute_value, []) + + for oat in optional: + tag = (oat.name, oat.friendly_name) + if tag not in res: + res[tag] = ([], oat.attribute_value) + + return res + +def filter_on_attributes(ava, required=None, optional=None): + """ Filter :param required: list of RequestedAttribute instances """ res = {} - for attr in required: - if attr.name in ava: - if required.attribute_value: - res[attr.name] = filter_values(ava[attr.name], - required.attribute_value) - else: - res[attr.name] = ava[attr.name] - elif attr.friendly_name in ava: - if attr.attribute_value: - res[attr.friendly_name] = filter_values( - ava[attr.friendly_name], - attr.attribute_value) - else: - res[attr.friendly_name] = ava[attr.friendly_name] + comb = combine(required, optional) + for attr, vals in comb.items(): + if attr[0] in ava: + res[attr[0]] = filter_values(ava[attr[0]], vals[0], vals[1]) + elif attr[1] in ava: + res[attr[1]] = filter_values(ava[attr[1]], vals[0], vals[1]) else: raise MissingValue("Required attribute missing") return res -def filter_optional(ava, optional): - """ - :param optional: list of RequestedAttribute instances - """ - res = {} - for attr in optional: - if attr.name in ava: - if optional.attribute_value: - res[attr.name] = filter_values(ava[attr.name], - optional.attribute_value, False) - else: - res[attr.name] = ava[attr.name] - elif attr.friendly_name in ava: - if attr.attribute_value: - res[attr.friendly_name] = filter_values( - ava[attr.friendly_name], - attr.attribute_value, False) - else: - res[attr.friendly_name] = ava[attr.friendly_name] - - return res #---------------------------------------------------------------------------- diff --git a/tests/test_utils.py b/tests/test_utils.py index b71554a..eb720d1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -377,25 +377,39 @@ def test_identity_attribute_4(): # if there would be a map it would be serialNumber assert utils.identity_attribute("friendly",a) == "serialNumber" -def test_filter_values_req_0(): +def test_combine_0(): + r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber") + o = Attribute(name="urn:oid:2.5.4.4", name_format=NAME_FORMAT_URI, + friendly_name="surName") + + comb = utils.combine([r],[o]) + print comb + assert _eq(comb.keys(), [('urn:oid:2.5.4.5', 'serialNumber'), + ('urn:oid:2.5.4.4', 'surName')]) + assert comb[('urn:oid:2.5.4.5', 'serialNumber')] == ([], []) + assert comb[('urn:oid:2.5.4.4', 'surName')] == ([], []) + + +def test_filter_on_attributes_0(): a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, friendly_name="serialNumber") required = [a] ava = { "serialNumber": ["12345"]} - ava = utils.filter_required(ava, required) + ava = utils.filter_on_attributes(ava, required) assert ava.keys() == ["serialNumber"] assert ava["serialNumber"] == ["12345"] -def test_filter_values_req_1(): +def test_filter_on_attributes_1(): a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, friendly_name="serialNumber") required = [a] ava = { "serialNumber": ["12345"], "givenName":["Lars"]} - ava = utils.filter_required(ava, required) + ava = utils.filter_on_attributes(ava, required) assert ava.keys() == ["serialNumber"] assert ava["serialNumber"] == ["12345"] @@ -408,7 +422,7 @@ def test_filter_values_req_2(): required = [a1,a2] ava = { "serialNumber": ["12345"], "givenName":["Lars"]} - raises(utils.MissingValue, utils.filter_required, ava, required) + raises(utils.MissingValue, utils.filter_on_attributes, ava, required) def test_filter_values_req_3(): a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, @@ -418,7 +432,7 @@ def test_filter_values_req_3(): required = [a] ava = { "serialNumber": ["12345"]} - ava = utils.filter_required(ava, required) + ava = utils.filter_on_attributes(ava, required) assert ava.keys() == ["serialNumber"] assert ava["serialNumber"] == ["12345"] @@ -430,7 +444,7 @@ def test_filter_values_req_4(): required = [a] ava = { "serialNumber": ["12345"]} - raises(utils.MissingValue, utils.filter_required, ava, required) + raises(utils.MissingValue, utils.filter_on_attributes, ava, required) def test_filter_values_req_5(): a = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, @@ -440,7 +454,7 @@ def test_filter_values_req_5(): required = [a] ava = { "serialNumber": ["12345", "54321"]} - ava = utils.filter_required(ava, required) + ava = utils.filter_on_attributes(ava, required) assert ava.keys() == ["serialNumber"] assert ava["serialNumber"] == ["12345"] @@ -452,6 +466,35 @@ def test_filter_values_req_6(): required = [a] ava = { "serialNumber": ["12345", "54321"]} - ava = utils.filter_required(ava, required) + ava = utils.filter_on_attributes(ava, required) assert ava.keys() == ["serialNumber"] assert ava["serialNumber"] == ["54321"] + +def test_filter_values_req_opt_0(): + r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", attribute_value=[ + AttributeValue(text="54321")]) + o = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", attribute_value=[ + AttributeValue(text="12345")]) + + ava = { "serialNumber": ["12345", "54321"]} + + ava = utils.filter_on_attributes(ava, [r], [o]) + assert ava.keys() == ["serialNumber"] + assert _eq(ava["serialNumber"], ["12345","54321"]) + +def test_filter_values_req_opt_1(): + r = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", attribute_value=[ + AttributeValue(text="54321")]) + o = Attribute(name="urn:oid:2.5.4.5", name_format=NAME_FORMAT_URI, + friendly_name="serialNumber", attribute_value=[ + AttributeValue(text="12345"), + AttributeValue(text="abcd0")]) + + ava = { "serialNumber": ["12345", "54321"]} + + ava = utils.filter_on_attributes(ava, [r], [o]) + assert ava.keys() == ["serialNumber"] + assert _eq(ava["serialNumber"], ["12345","54321"])