diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index b539bef..4047b81 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -513,9 +513,10 @@ class Base(Entity): kwargs = {"entity_id": self.config.entityid, "attribute_converters": self.config.attribute_converters} - return self._parse_response(response, AssertionIDResponse, "", binding, + res = self._parse_response(response, AssertionIDResponse, "", binding, **kwargs) - + return res + # ------------------------------------------------------------------------ def parse_attribute_query_response(self, response, binding): diff --git a/src/saml2/config.py b/src/saml2/config.py index 08de393..941fe76 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -10,7 +10,7 @@ import logging.handlers from importlib import import_module -from saml2 import root_logger +from saml2 import root_logger, BINDING_URI from saml2 import BINDING_SOAP from saml2 import BINDING_HTTP_REDIRECT from saml2 import BINDING_HTTP_POST @@ -137,7 +137,7 @@ PREFERRED_BINDING={ "authn_query_service": [BINDING_SOAP], "attribute_service": [BINDING_SOAP], "authz_service": [BINDING_SOAP], - "assertion_id_request_service": [BINDING_SOAP], + "assertion_id_request_service": [BINDING_URI], "artifact_resolution_service": [BINDING_HTTP_REDIRECT, BINDING_HTTP_POST], "attribute_consuming_service": _RPA } diff --git a/src/saml2/entity.py b/src/saml2/entity.py index e0c8699..03c9f04 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -4,7 +4,7 @@ from hashlib import sha1 from saml2.metadata import ENDPOINTS from saml2.soap import parse_soap_enveloped_saml_artifact_resolve -from saml2 import samlp, saml, response +from saml2 import samlp, saml, response, BINDING_URI from saml2 import request from saml2 import soap from saml2 import element_to_extension_element @@ -120,9 +120,18 @@ class Entity(HTTPBase): return Issuer(text=self.config.entityid, format=NAMEID_FORMAT_ENTITY) - def apply_binding(self, binding, msg_str, destination, relay_state, + def apply_binding(self, binding, msg_str, destination="", relay_state="", typ="SAMLRequest"): + """ + Construct the necessary HTTP arguments dependent on Binding + :param binding: Which binding to use + :param msg_str: The return message as a string (XML) + :param destination: Where to send the message + :param relay_state: Relay_state if provided + :param typ: Which type of message this is + :return: A dictionary + """ if binding == BINDING_HTTP_POST: logger.info("HTTP POST") info = self.use_http_form_post(msg_str, destination, @@ -136,6 +145,8 @@ class Entity(HTTPBase): info["method"] = "GET" elif binding == BINDING_SOAP: info = self.use_soap(msg_str, destination) + elif binding == BINDING_URI: + info = self.use_http_uri(msg_str, typ, destination) else: raise Exception("Unknown binding type: %s" % binding) @@ -227,8 +238,10 @@ class Entity(HTTPBase): elif binding == BINDING_SOAP: func = getattr(soap, "parse_soap_enveloped_saml_%s" % msgtype) xmlstr = func(txt) + elif binding == BINDING_URI: + xmlstr = txt else: - raise ValueError("Don't know how to handle '%s'") + raise ValueError("Don't know how to handle '%s'" % binding) return xmlstr diff --git a/src/saml2/httpbase.py b/src/saml2/httpbase.py index 52520a3..f438e36 100644 --- a/src/saml2/httpbase.py +++ b/src/saml2/httpbase.py @@ -13,7 +13,6 @@ from saml2.pack import make_soap_enveloped_saml_thingy from saml2.pack import http_redirect_message import logging -from saml2.soap import parse_soap_enveloped_saml_response logger = logging.getLogger(__name__) @@ -204,6 +203,29 @@ class HTTPBase(object): return http_redirect_message(message, destination, relay_state, typ) + def use_http_uri(self, message, typ, destination=""): + if typ == "SAMLResponse": + info = { + "data": message.split("\n")[1], + "headers": [ + ("Content-Type", "application/samlassertion+xml"), + ("Cache-Control", "no-cache, no-store"), + ("Pragma", "no-cache") + ] + } + elif typ == "SAMLRequest": + # msg should be an identifier + info = { + "data": "", + "headers": [ + ("Location", "%s?ID=%s" % (destination, message)) + ] + } + else: + raise NotImplemented + + return info + def use_soap(self, request, destination="", headers=None, sign=False): """ Construct the necessary information for using SOAP+POST diff --git a/src/saml2/ident.py b/src/saml2/ident.py index cd76d19..a78048c 100644 --- a/src/saml2/ident.py +++ b/src/saml2/ident.py @@ -8,6 +8,7 @@ from urllib import unquote from saml2.s_utils import rndstr from saml2.s_utils import PolicyError from saml2.saml import NameID +from saml2.saml import NAMEID_FORMAT_PERSISTENT from saml2.saml import NAMEID_FORMAT_TRANSIENT from saml2.saml import NAMEID_FORMAT_EMAILADDRESS @@ -156,13 +157,21 @@ class IdentDB(object): :param sp_name_qualifier: The 'user'/-s of the name_id :param name_id_policy: The policy the server on the other side wants us to follow. - :param sp_nid: Name ID Formats from the SPs metadata + :param name_qualifier: A domain qualifier :return: NameID instance precursor """ args = self.nim_args(local_policy, sp_name_qualifier, name_id_policy) return self.get_nameid(userid, **args) + def transient_nameid(self, userid, sp_name_qualifier="", name_qualifier=""): + return self.get_nameid(userid, NAMEID_FORMAT_TRANSIENT, + sp_name_qualifier, name_qualifier) + + def permanent_nameid(self, userid, sp_name_qualifier="", name_qualifier=""): + return self.get_nameid(userid, NAMEID_FORMAT_PERSISTENT, + sp_name_qualifier, name_qualifier) + def find_local_id(self, name_id): """ Only find on persistent IDs diff --git a/src/saml2/response.py b/src/saml2/response.py index a864cb0..0f7d846 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -154,17 +154,10 @@ class StatusResponse(object): return self._postamble() def _loads(self, xmldata, decode=True, origxml=None): -# if decode: -# decoded_xml = base64.b64decode(xmldata) -# else: -# decoded_xml = xmldata - + # own copy self.xmlstr = xmldata[:] logger.debug("xmlstr: %s" % (self.xmlstr,)) -# fil = open("response.xml", "w") -# fil.write(self.xmlstr) -# fil.close() try: self.response = self.signature_check(xmldata, origdoc=origxml) @@ -641,20 +634,6 @@ class AuthnResponse(StatusResponse): def __str__(self): return "%s" % self.xmlstr -class AssertionIDResponse(AuthnResponse): - msgtype = "assertion_id_response" - - def __init__(self, sec_context, attribute_converters, entity_id, - return_addr=None, timeslack=0, asynchop=False, test=False): - - AuthnResponse.__init__(self, sec_context, attribute_converters, - entity_id, return_addr, timeslack=timeslack, - asynchop=asynchop, test=test) - self.entity_id = entity_id - self.attribute_converters = attribute_converters - self.assertion = None - self.context = "AssertionIdResponse" - class AuthnQueryResponse(AuthnResponse): msgtype = "authn_query_response" @@ -747,3 +726,62 @@ def response_factory(xmlstr, conf, return_addr=None, return logoutresp return response + +# =========================================================================== +# A class of it's own + +class AssertionIDResponse(object): + msgtype = "assertion_id_response" + + def __init__(self, sec_context, attribute_converters, timeslack=0, + **kwargs): + + self.sec = sec_context + self.timeslack = timeslack + self.xmlstr = "" + self.name_id = "" + self.response = None + self.not_signed = False + self.attribute_converters = attribute_converters + self.assertion = None + self.context = "AssertionIdResponse" + self.signature_check = self.sec.correctly_signed_assertion_id_response + + def loads(self, xmldata, decode=True, origxml=None): + # own copy + self.xmlstr = xmldata[:] + logger.debug("xmlstr: %s" % (self.xmlstr,)) + + try: + self.response = self.signature_check(xmldata, origdoc=origxml) + self.assertion = self.response + except TypeError: + raise + except SignatureError: + raise + except Exception, excp: + logger.exception("EXCEPTION: %s", excp) + raise + + #print "<", self.response + + return self._postamble() + + def verify(self): + try: + valid_instance(self.response) + except NotValid, exc: + logger.error("Not valid response: %s" % exc.args[0]) + raise + return self + + def _postamble(self): + if not self.response: + logger.error("Response was not correctly signed") + if self.xmlstr: + logger.info(self.xmlstr) + raise IncorrectlySigned() + + logger.debug("response: %s" % (self.response,)) + + return self diff --git a/src/saml2/s_utils.py b/src/saml2/s_utils.py index abe1c60..38c9aeb 100644 --- a/src/saml2/s_utils.py +++ b/src/saml2/s_utils.py @@ -29,8 +29,11 @@ logger = logging.getLogger(__name__) class VersionMismatch(Exception): pass - -class UnknownPrincipal(Exception): + +class Unknown(Exception): + pass + +class UnknownPrincipal(Unknown): pass class UnsupportedBinding(Exception): @@ -45,6 +48,10 @@ class MissingValue(Exception): class PolicyError(Exception): pass +class BadRequest(Exception): + pass + + EXCEPTION2STATUS = { VersionMismatch: samlp.STATUS_VERSION_MISMATCH, UnknownPrincipal: samlp.STATUS_UNKNOWN_PRINCIPAL, diff --git a/src/saml2/server.py b/src/saml2/server.py index c28b72f..a4822b2 100644 --- a/src/saml2/server.py +++ b/src/saml2/server.py @@ -39,10 +39,11 @@ from saml2.request import NameIDMappingRequest from saml2.request import AuthzDecisionQuery from saml2.request import AuthnQuery -from saml2.s_utils import MissingValue +from saml2.s_utils import MissingValue, Unknown +from saml2.s_utils import BadRequest from saml2.s_utils import error_status_factory -from saml2.sigver import pre_signature_part +from saml2.sigver import pre_signature_part, signed_instance_factory from saml2.assertion import Assertion from saml2.assertion import Policy @@ -461,9 +462,7 @@ class Server(Entity): return self.create_error_response(in_response_to, destination, sp_entity_id, exc, name_id) - def create_assertion_id_request_response(self, assertion_id, in_response_to, - issuer=None, sign_response=False, - status=None): + def create_assertion_id_request_response(self, assertion_id, sign=False): """ :param assertion_id: @@ -473,23 +472,20 @@ class Server(Entity): :param status: :return: """ - # Done over SOAP - args = {} - to_sign = [] - for aid in assertion_id: - try: - (assertion, to_sign) = self.get_assertion(aid) - to_sign.extend(to_sign) - try: - args["assertion"].append(assertion) - except KeyError: - args["assertion"] = [assertion] - except KeyError: - pass + try: + (assertion, to_sign) = self.get_assertion(assertion_id) + except KeyError: + raise Unknown - return self._response(in_response_to, "", status, issuer, - sign_response, to_sign, **args) + if to_sign: + if assertion.signature is None: + assertion.signature = pre_signature_part(assertion.id, + self.sec.my_cert, 1) + + return signed_instance_factory(assertion, self.sec, to_sign) + else: + return assertion def create_name_id_mapping_response(self, name_id=None, encrypted_id=None, in_response_to=None, diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 2d3065d..b6ef025 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -746,7 +746,11 @@ class SecurityContext(object): :return: """ - _func = getattr(samlp, "%s_from_string" % msgtype) + try: + _func = getattr(samlp, "%s_from_string" % msgtype) + except AttributeError: + _func = getattr(saml, "%s_from_string" % msgtype) + msg = _func(decoded_xml) if not msg: raise TypeError("Not a %s" % msgtype) @@ -839,6 +843,11 @@ class SecurityContext(object): "assertion_id_request", must, origdoc) + def correctly_signed_assertion_id_response(self, decoded_xml, must=False, + origdoc=None): + return self.correctly_signed_message(decoded_xml, "assertion", must, + origdoc) + def correctly_signed_response(self, decoded_xml, must=False, origdoc=None): """ Check if a instance is correctly signed, if we have metadata for the IdP that sent the info use that, if not use the key that are in diff --git a/tests/idp_all_conf.py b/tests/idp_all_conf.py index c6e3b0e..b7f98a4 100644 --- a/tests/idp_all_conf.py +++ b/tests/idp_all_conf.py @@ -1,4 +1,4 @@ -from saml2 import BINDING_SOAP +from saml2 import BINDING_SOAP, BINDING_URI from saml2 import BINDING_HTTP_REDIRECT from saml2 import BINDING_HTTP_POST from saml2 import BINDING_HTTP_ARTIFACT @@ -51,7 +51,7 @@ CONFIG = { ("%s/ars" % BASE, BINDING_SOAP) ], "assertion_id_request_service": [ - ("%s/airs" % BASE, BINDING_SOAP) + ("%s/airs" % BASE, BINDING_URI) ], "authn_query_service": [ ("%s/aqs" % BASE, BINDING_SOAP) diff --git a/tests/test_33_identifier.py b/tests/test_33_identifier.py index d8a266f..15f1d26 100644 --- a/tests/test_33_identifier.py +++ b/tests/test_33_identifier.py @@ -3,7 +3,7 @@ from saml2 import samlp from saml2.saml import NAMEID_FORMAT_PERSISTENT, NAMEID_FORMAT_TRANSIENT from saml2.config import IdPConfig -from saml2.server import Identifier +from saml2.ident import IdentDB from saml2.assertion import Policy def _eq(l1,l2): @@ -54,7 +54,7 @@ NAME_ID_POLICY_2 = """ class TestIdentifier(): def setup_class(self): - self.id = Identifier("subject.db", CONFIG.vorg) + self.id = IdentDB("subject.db", "example.com", "example") def test_persistent_1(self): policy = Policy({ @@ -67,21 +67,18 @@ class TestIdentifier(): } }) - nameid = self.id.construct_nameid(policy, "foobar", - "urn:mace:example.com:sp:1") + nameid = self.id.construct_nameid("foobar", policy, + "urn:mace:example.com:sp:1") - assert _eq(nameid.keys(), ['text', 'sp_provided_id', - 'sp_name_qualifier', 'name_qualifier', 'format']) - assert _eq(nameid.keyswv(), ['format', 'text', 'sp_name_qualifier']) + assert _eq(nameid.keyswv(), ['format', 'text', 'sp_name_qualifier', + 'name_qualifier']) assert nameid.sp_name_qualifier == "urn:mace:example.com:sp:1" assert nameid.format == NAMEID_FORMAT_PERSISTENT - nameid_2 = self.id.construct_nameid(policy, "foobar", - "urn:mace:example.com:sp:1") - - assert nameid != nameid_2 - assert nameid.text == nameid_2.text + id = self.id.find_local_id(nameid) + assert id == "foobar" + def test_transient_1(self): policy = Policy({ "default": { @@ -92,10 +89,11 @@ class TestIdentifier(): } } }) - nameid = self.id.construct_nameid(policy, "foobar", - "urn:mace:example.com:sp:1") + nameid = self.id.construct_nameid("foobar", policy, + "urn:mace:example.com:sp:1") - assert _eq(nameid.keyswv(), ['text', 'format', 'sp_name_qualifier']) + assert _eq(nameid.keyswv(), ['text', 'format', 'sp_name_qualifier', + 'name_qualifier']) assert nameid.format == NAMEID_FORMAT_TRANSIENT def test_vo_1(self): @@ -111,17 +109,16 @@ class TestIdentifier(): name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_1) print name_id_policy - print self.id.voconf - nameid = self.id.construct_nameid(policy, "foobar", - "urn:mace:example.com:sp:1", - {"uid": "foobar01"}, - name_id_policy) + nameid = self.id.construct_nameid("foobar", policy, + 'http://vo.example.org/biomed', + name_id_policy) print nameid - assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format']) + assert _eq(nameid.keyswv(), ['text', 'sp_name_qualifier', 'format', + 'name_qualifier']) assert nameid.sp_name_qualifier == 'http://vo.example.org/biomed' assert nameid.format == 'urn:oid:2.16.756.1.2.5.1.1.1-NameID' - assert nameid.text == "foobar01" + assert nameid.text == "foobar" def test_vo_2(self): policy = Policy({ @@ -136,8 +133,8 @@ class TestIdentifier(): name_id_policy = samlp.name_id_policy_from_string(NAME_ID_POLICY_2) - nameid = self.id.construct_nameid(policy, "foobar", - "urn:mace:example.com:sp:1", + nameid = self.id.construct_nameid("foobar", policy, + "urn:mace:example.com:sp:1", {"uid": "foobar01"}, name_id_policy) diff --git a/tests/test_68_assertion_id.py b/tests/test_68_assertion_id.py index 0bfc383..485fbeb 100644 --- a/tests/test_68_assertion_id.py +++ b/tests/test_68_assertion_id.py @@ -2,9 +2,10 @@ from urlparse import parse_qs from urlparse import urlparse from saml2.samlp import AuthnRequest from saml2.samlp import NameIDPolicy -from saml2.saml import AUTHN_PASSWORD +from saml2.saml import AUTHN_PASSWORD, Assertion from saml2.saml import NAMEID_FORMAT_TRANSIENT from saml2 import BINDING_HTTP_POST +from saml2 import BINDING_URI from saml2 import BINDING_SOAP from saml2.client import Saml2Client from saml2.server import Server @@ -13,20 +14,28 @@ __author__ = 'rolandh' TAG1 = "name=\"SAMLRequest\" value=" -def get_msg(hinfo, binding): +def get_msg(hinfo, binding, response=False): if binding == BINDING_SOAP: - xmlstr = hinfo["data"] + msg = hinfo["data"] elif binding == BINDING_HTTP_POST: _inp = hinfo["data"][3] i = _inp.find(TAG1) i += len(TAG1) + 1 j = _inp.find('"', i) - xmlstr = _inp[i:j] + msg = _inp[i:j] + elif binding == BINDING_URI: + if response: + msg = hinfo["data"] + else: + msg = "" + for header, val in hinfo["headers"]: + if header == "Location": + return parse_qs(val.split("?")[1])["ID"][0] else: # BINDING_HTTP_REDIRECT parts = urlparse(hinfo["headers"][0][1]) - xmlstr = parse_qs(parts.query)["SAMLRequest"][0] + msg = parse_qs(parts.query)["SAMLRequest"][0] - return xmlstr + return msg def test_basic_flow(): sp = Saml2Client(config_file="servera_conf") @@ -42,7 +51,8 @@ def test_basic_flow(): # == Create an AuthnRequest response - name_id = idp.ident.transient_nameid(sp.config.entityid, "id12") + name_id = idp.ident.transient_nameid("id12", sp.config.entityid) + binding, destination = idp.pick_binding("assertion_consumer_service", entity_id=sp.config.entityid) resp = idp.create_authn_response({"eduPersonEntitlement": "Short stop", @@ -73,32 +83,23 @@ def test_basic_flow(): binding, destination = sp.pick_binding("assertion_id_request_service", entity_id=idp.config.entityid) - _req = sp.create_assertion_id_request([asid], destination) - - hinfo = sp.apply_binding(binding, "%s" % _req, destination, - "realy_stat") + hinfo = sp.apply_binding(binding, asid, destination) # ---------- @IDP ------------ - xmlstr = get_msg(hinfo, binding) - - rr = idp.parse_assertion_id_request(xmlstr, binding) - - print rr + aid = get_msg(hinfo, binding, response=False) # == construct response - aids = [x.text for x in rr.message.assertion_id_ref] - resp_args = idp.response_args(rr.message) - - resp = idp.create_assertion_id_request_response(aids, **resp_args) + resp = idp.create_assertion_id_request_response(aid) hinfo = idp.apply_binding(binding, "%s" % resp, None, "", "SAMLResponse") # ----------- @SP ------------- - xmlstr = get_msg(hinfo, binding) + xmlstr = get_msg(hinfo, binding, response=True) final = sp.parse_assertion_id_request_response(xmlstr, binding) - print final \ No newline at end of file + print final.response + assert isinstance(final.response, Assertion) \ No newline at end of file