More refactoring

This commit is contained in:
Roland Hedberg
2012-12-30 12:46:36 +01:00
parent b655962793
commit 956ff0a536
9 changed files with 222 additions and 293 deletions

View File

@@ -1,12 +1,12 @@
#!/usr/bin/env python #!/usr/bin/env python
import re import re
import base64
import logging import logging
#from cgi import parse_qs #from cgi import parse_qs
from urlparse import parse_qs from urlparse import parse_qs
from saml2.pack import http_form_post_message from saml2.pack import http_form_post_message
from saml2.s_utils import OtherError
from saml2.saml import AUTHN_PASSWORD from saml2.saml import AUTHN_PASSWORD
from saml2 import server from saml2 import server
from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST from saml2 import BINDING_HTTP_REDIRECT, BINDING_HTTP_POST
@@ -89,33 +89,53 @@ def sso(environ, start_response, user):
return ['Unknown user'] return ['Unknown user']
# base 64 encoded request # base 64 encoded request
req_info = IDP.parse_authn_request(query["SAMLRequest"][0]) # Assume default binding, that is HTTP-redirect
req = IDP.parse_authn_request(query["SAMLRequest"][0])
if req is None:
start_response("500", [('Content-Type', 'text/plain')])
return ["Failed to parse the SAML request"]
logger.info("parsed OK") logger.info("parsed OK")
logger.info("%s" % req_info) logger.info("%s" % req)
identity = dict(environ["repoze.who.identity"]["user"]) identity = dict(environ["repoze.who.identity"]["user"])
logger.info("Identity: %s" % (identity,)) logger.info("Identity: %s" % (identity,))
userid = environ["repoze.who.identity"]['repoze.who.userid'] userid = environ["repoze.who.identity"]['repoze.who.userid']
if REPOZE_ID_EQUIVALENT: if REPOZE_ID_EQUIVALENT:
identity[REPOZE_ID_EQUIVALENT] = userid identity[REPOZE_ID_EQUIVALENT] = userid
# What's the binding ? ProtocolBinding
if req.message.protocol_binding == BINDING_HTTP_REDIRECT:
_binding = BINDING_HTTP_POST
else:
_binding = req.message.protocol_binding
try: try:
authn_resp = IDP.create_authn_response(identity, resp_args = IDP.response_args(req.message, [_binding])
req_info["id"], except Exception:
req_info["consumer_url"],
req_info["sp_entity_id"],
req_info["request"].name_id_policy,
userid,
authn=AUTHN)
except Exception, excp:
if logger: logger.error("Exception: %s" % (excp,))
raise raise
if logger: logger.info("AuthNResponse: %s" % authn_resp) if req.message.assertion_consumer_service_url:
if req.message.assertion_consumer_service_url != resp_args["destination"]:
# serious error on someones behalf
logger.error("%s != %s" % (req.message.assertion_consumer_service_url,
resp_args["destination"]))
raise OtherError("ConsumerURL and return destination mismatch")
headers, response = http_form_post_message(authn_resp, try:
req_info["consumer_url"], "/") authn_resp = IDP.create_authn_response(identity, userid, authn=AUTHN,
start_response('200 OK', headers) **resp_args)
return response except Exception, excp:
logger.error("Exception: %s" % (excp,))
raise
logger.info("AuthNResponse: %s" % authn_resp)
http_args = http_form_post_message(authn_resp, resp_args["destination"],
relay_state=query["RelayState"])
start_response('200 OK', http_args["headers"])
return http_args["data"]
def whoami(environ, start_response, user): def whoami(environ, start_response, user):
start_response('200 OK', [('Content-Type', 'text/html')]) start_response('200 OK', [('Content-Type', 'text/html')])
@@ -165,26 +185,36 @@ def slo(environ, start_response, user):
# look for the subject # look for the subject
subject = req_info.subject_id() subject = req_info.subject_id()
subject = subject.text.strip() subject = subject.text.strip()
sp_entity_id = req_info.message.issuer.text.strip()
logger.info("Logout subject: %s" % (subject,)) logger.info("Logout subject: %s" % (subject,))
logger.info("local identifier: %s" % IDP.ident.local_name(sp_entity_id,
subject))
# remove the authentication
status = None status = None
# Either HTTP-Post or HTTP-redirect is possible # Either HTTP-Post or HTTP-redirect is possible, prefer HTTP-Post.
# Order matters
bindings = [BINDING_HTTP_POST, BINDING_HTTP_REDIRECT] bindings = [BINDING_HTTP_POST, BINDING_HTTP_REDIRECT]
(resp, headers, message) = IDP.create_logout_response(req_info.message, try:
response = IDP.create_logout_response(req_info.message,
bindings) bindings)
#headers.append(session.cookie(expire="now")) binding, destination = IDP.pick_binding(bindings,
logger.info("Response code: %s" % (resp,)) "single_logout_service",
logger.info("Header: %s" % (headers,)) "spsso", response)
http_args = IDP.apply_binding(binding, "%s" % response, destination,
query["RelayState"], "SAMLResponse")
except Exception, exc:
start_response('400 Bad request', [('Content-Type', 'text/plain')])
return ['%s' % exc]
delco = delete_cookie(environ, "pysaml2idp") delco = delete_cookie(environ, "pysaml2idp")
if delco: if delco:
headers.append(delco) http_args["headers"].append(delco)
start_response(resp, headers)
return message if binding == BINDING_HTTP_POST:
start_response("200 OK", http_args["headers"])
else:
start_response("302 Found", http_args["headers"])
return http_args["data"]
def delete_cookie(environ, name): def delete_cookie(environ, name):
kaka = environ.get("HTTP_COOKIE", '') kaka = environ.get("HTTP_COOKIE", '')

View File

@@ -18,6 +18,7 @@
"""Contains classes and functions that a SAML2.0 Service Provider (SP) may use """Contains classes and functions that a SAML2.0 Service Provider (SP) may use
to conclude its tasks. to conclude its tasks.
""" """
from saml2.s_utils import sid
from saml2.samlp import logout_response_from_string from saml2.samlp import logout_response_from_string
import saml2 import saml2
@@ -50,28 +51,6 @@ logger = logging.getLogger(__name__)
class Saml2Client(Base): class Saml2Client(Base):
""" The basic pySAML2 service provider class """ """ The basic pySAML2 service provider class """
def _request_info(self, binding, req_str, destination, relay_state):
if binding == saml2.BINDING_HTTP_POST:
logger.info("HTTP POST")
info = self.use_http_form_post(req_str, destination,
relay_state)
info["url"] = destination
info["method"] = "GET"
elif binding == saml2.BINDING_HTTP_REDIRECT:
logger.info("HTTP REDIRECT")
info = self.use_http_get(req_str, destination,
relay_state)
info["url"] = destination
info["method"] = "GET"
elif binding == BINDING_SOAP:
info = self.use_soap(req_str, destination)
else:
raise Exception("Unknown binding type: %s" % binding)
return info
def prepare_for_authenticate(self, entityid=None, relay_state="", def prepare_for_authenticate(self, entityid=None, relay_state="",
binding=saml2.BINDING_HTTP_REDIRECT, vorg="", binding=saml2.BINDING_HTTP_REDIRECT, vorg="",
nameid_format=NAMEID_FORMAT_PERSISTENT, nameid_format=NAMEID_FORMAT_PERSISTENT,
@@ -100,7 +79,7 @@ class Saml2Client(Base):
logger.info("AuthNReq: %s" % _req_str) logger.info("AuthNReq: %s" % _req_str)
info = self._request_info(binding, _req_str, destination, relay_state) info = self.apply_binding(binding, _req_str, destination, relay_state)
return req.id, info return req.id, info
@@ -156,8 +135,8 @@ class Saml2Client(Base):
for binding in [#BINDING_SOAP, for binding in [#BINDING_SOAP,
BINDING_HTTP_POST, BINDING_HTTP_POST,
BINDING_HTTP_REDIRECT]: BINDING_HTTP_REDIRECT]:
srvs = self.metadata.single_logout_service(entity_id, "idpsso", srvs = self.metadata.single_logout_service(entity_id, binding,
binding=binding) "idpsso")
if not srvs: if not srvs:
continue continue
@@ -185,8 +164,8 @@ class Saml2Client(Base):
srequest = signed_instance_factory(request, self.sec, to_sign) srequest = signed_instance_factory(request, self.sec, to_sign)
relay_state = self._relay_state(request.id) relay_state = self._relay_state(request.id)
http_info = self._request_info(binding, srequest, http_info = self.apply_binding(binding, srequest, destination,
destination, relay_state) relay_state)
if binding == BINDING_SOAP: if binding == BINDING_SOAP:
if response: if response:
@@ -368,20 +347,29 @@ class Saml2Client(Base):
:param real_id: The identifier which is the key to this entity in the :param real_id: The identifier which is the key to this entity in the
identity database identity database
:param binding: Which binding to use :param binding: Which binding to use
:return: The attributes returned :return: The attributes returned if BINDING_SOAP was used.
HTTP args if BINDING_HTT_POST was used.
""" """
srvs = self.metadata.attribute_service(entityid, binding)
if srvs == []:
raise Exception("No attribute service support at entity")
destination = destinations(srvs)[0]
if real_id: if real_id:
response_args = {"real_id": real_id} response_args = {"real_id": real_id}
else: else:
response_args = {} response_args = {}
if not binding:
binding, destination = self.pick_binding([BINDING_SOAP,
BINDING_HTTP_POST],
"attribute_service",
"attribute_authority",
entity_id=entityid)
else:
srvs = self.metadata.attribute_service(entityid, binding)
if srvs is []:
raise Exception("No attribute service support at entity")
destination = destinations(srvs)[0]
if binding == BINDING_SOAP: if binding == BINDING_SOAP:
return self._use_soap(destination, "attribute_query", return self._use_soap(destination, "attribute_query",
consent=consent, extensions=extensions, consent=consent, extensions=extensions,
@@ -392,13 +380,18 @@ class Saml2Client(Base):
nameid_format=nameid_format, nameid_format=nameid_format,
response_args=response_args) response_args=response_args)
elif binding == BINDING_HTTP_POST: elif binding == BINDING_HTTP_POST:
return self._use_soap(destination, "attribute_query", mid = sid()
consent=consent, extensions=extensions, query = self.create_attribute_query(destination, subject_id,
sign=sign, subject_id=subject_id, attribute, sp_name_qualifier,
attribute=attribute, name_qualifier, nameid_format,
sp_name_qualifier=sp_name_qualifier, mid, consent, extensions,
name_qualifier=name_qualifier, sign)
nameid_format=nameid_format, self.state[query.id] = {"entity_id": entityid,
response_args=response_args) "operation": "AttributeQuery",
"subject_id": subject_id,
"sign": sign}
relay_state = self._relay_state(query.id)
return self.apply_binding(binding,"%s" % query, destination,
relay_state)
else: else:
raise Exception("Unsupported binding") raise Exception("Unsupported binding")

View File

@@ -18,8 +18,8 @@
"""Contains classes and functions that a SAML2.0 Service Provider (SP) may use """Contains classes and functions that a SAML2.0 Service Provider (SP) may use
to conclude its tasks. to conclude its tasks.
""" """
from saml2.entity import Entity
from saml2.httpbase import HTTPBase
from saml2.mdstore import destinations from saml2.mdstore import destinations
from saml2.saml import AssertionIDRef, NAMEID_FORMAT_TRANSIENT from saml2.saml import AssertionIDRef, NAMEID_FORMAT_TRANSIENT
from saml2.samlp import AuthnQuery, ArtifactResponse, StatusCode, Status from saml2.samlp import AuthnQuery, ArtifactResponse, StatusCode, Status
@@ -52,10 +52,8 @@ from saml2.s_utils import decode_base64_and_inflate
from saml2 import samlp, saml, class_name from saml2 import samlp, saml, class_name
from saml2 import VERSION from saml2 import VERSION
from saml2.sigver import pre_signature_part from saml2.sigver import pre_signature_part
from saml2.sigver import security_context, signed_instance_factory from saml2.sigver import signed_instance_factory
from saml2.population import Population from saml2.population import Population
from saml2.virtual_org import VirtualOrg
from saml2.config import config_factory
from saml2.response import response_factory, attribute_response from saml2.response import response_factory, attribute_response
from saml2.response import LogoutResponse from saml2.response import LogoutResponse
@@ -92,7 +90,7 @@ class LogoutError(Exception):
class NoServiceDefined(Exception): class NoServiceDefined(Exception):
pass pass
class Base(HTTPBase): class Base(Entity):
""" The basic pySAML2 service provider class """ """ The basic pySAML2 service provider class """
def __init__(self, config=None, identity_cache=None, state_cache=None, def __init__(self, config=None, identity_cache=None, state_cache=None,
@@ -104,6 +102,8 @@ class Base(HTTPBase):
:param virtual_organization: A specific virtual organization :param virtual_organization: A specific virtual organization
""" """
Entity.__init__(self, "sp", config, config_file, virtual_organization)
self.users = Population(identity_cache) self.users = Population(identity_cache)
# for server state storage # for server state storage
@@ -112,39 +112,6 @@ class Base(HTTPBase):
else: else:
self.state = state_cache self.state = state_cache
if config:
self.config = config
elif config_file:
self.config = config_factory("sp", config_file)
else:
raise Exception("Missing configuration")
HTTPBase.__init__(self, self.config.verify_ssl_cert,
self.config.ca_certs, self.config.key_file,
self.config.cert_file)
if self.config.vorg:
for vo in self.config.vorg.values():
vo.sp = self
self.metadata = self.config.metadata
self.config.setup_logger()
# we copy the config.debug variable in an internal
# field for convenience and because we may need to
# change it during the tests
self.debug = self.config.debug
self.sec = security_context(self.config)
if virtual_organization:
if isinstance(virtual_organization, basestring):
self.vorg = self.config.vorg[virtual_organization]
elif isinstance(virtual_organization, VirtualOrg):
self.vorg = virtual_organization
else:
self.vorg = None
for foo in ["allow_unsolicited", "authn_requests_signed", for foo in ["allow_unsolicited", "authn_requests_signed",
"logout_requests_signed"]: "logout_requests_signed"]:
if self.config.getattr("sp", foo) == 'true': if self.config.getattr("sp", foo) == 'true':

View File

@@ -144,7 +144,8 @@ class HTTPBase(object):
return r return r
def use_http_form_post(self, message, destination, relay_state): def use_http_form_post(self, message, destination, relay_state,
typ="SAMLRequest"):
""" """
Return a form that will automagically execute and POST the message Return a form that will automagically execute and POST the message
to the recipient. to the recipient.
@@ -152,14 +153,16 @@ class HTTPBase(object):
:param message: :param message:
:param destination: :param destination:
:param relay_state: :param relay_state:
:return: tuple (header, message) :param typ: Whether a Request, Response or Artifact
:return: dictionary
""" """
if not isinstance(message, basestring): if not isinstance(message, basestring):
request = "%s" % (message,) request = "%s" % (message,)
return http_form_post_message(message, destination, relay_state) return http_form_post_message(message, destination, relay_state, typ)
def use_http_get(self, message, destination, relay_state): def use_http_get(self, message, destination, relay_state,
typ="SAMLRequest"):
""" """
Send a message using GET, this is the HTTP-Redirect case so Send a message using GET, this is the HTTP-Redirect case so
no direct response is expected to this request. no direct response is expected to this request.
@@ -167,12 +170,13 @@ class HTTPBase(object):
:param message: :param message:
:param destination: :param destination:
:param relay_state: :param relay_state:
:return: tuple (header, None) :param typ: Whether a Request, Response or Artifact
:return: dictionary
""" """
if not isinstance(message, basestring): if not isinstance(message, basestring):
request = "%s" % (message,) request = "%s" % (message,)
return http_redirect_message(message, destination, relay_state) return http_redirect_message(message, destination, relay_state, typ)
def use_soap(self, request, destination, headers=None, sign=False): def use_soap(self, request, destination, headers=None, sign=False):
""" """
@@ -182,7 +186,7 @@ class HTTPBase(object):
:param destination: :param destination:
:param headers: :param headers:
:param sign: :param sign:
:return: :return: dictionary
""" """
if headers is None: if headers is None:
headers = {"content-type": "application/soap+xml"} headers = {"content-type": "application/soap+xml"}
@@ -224,4 +228,3 @@ class HTTPBase(object):
return parse_soap_enveloped_saml_response(response) return parse_soap_enveloped_saml_response(response)
else: else:
return False return False

View File

@@ -306,7 +306,7 @@ class MetadataStore(object):
return srvs return srvs
return [] return []
def single_sign_on_service(self, entity_id, binding=None): def single_sign_on_service(self, entity_id, binding=None, typ="idpsso"):
# IDP # IDP
if binding is None: if binding is None:
@@ -314,70 +314,76 @@ class MetadataStore(object):
return self._service(entity_id, "idpsso_descriptor", return self._service(entity_id, "idpsso_descriptor",
"single_sign_on_service", binding) "single_sign_on_service", binding)
def name_id_mapping_service(self, entity_id, binding=None): def name_id_mapping_service(self, entity_id, binding=None, typ="idpsso"):
# IDP # IDP
if binding is None: if binding is None:
binding = BINDING_HTTP_REDIRECT binding = BINDING_HTTP_REDIRECT
return self._service(entity_id, "idpsso_descriptor", return self._service(entity_id, "idpsso_descriptor",
"name_id_mapping_service", binding) "name_id_mapping_service", binding)
def authn_query_service(self, entity_id, binding=None): def authn_query_service(self, entity_id, binding=None,
typ="authn_authority"):
# AuthnAuthority # AuthnAuthority
if binding is None: if binding is None:
binding = BINDING_SOAP binding = BINDING_SOAP
return self._service(entity_id, "authn_authority_descriptor", return self._service(entity_id, "authn_authority_descriptor",
"authn_query_service", binding) "authn_query_service", binding)
def attribute_service(self, entity_id, binding=None): def attribute_service(self, entity_id, binding=None,
typ="attribute_authority"):
# AttributeAuthority # AttributeAuthority
if binding is None: if binding is None:
binding = BINDING_HTTP_REDIRECT binding = BINDING_HTTP_REDIRECT
return self._service(entity_id, "attribute_authority_descriptor", return self._service(entity_id, "attribute_authority_descriptor",
"attribute_service", binding) "attribute_service", binding)
def authz_service(self, entity_id, binding=None): def authz_service(self, entity_id, binding=None, typ="pdp"):
# PDP # PDP
if binding is None: if binding is None:
binding = BINDING_SOAP binding = BINDING_SOAP
return self._service(entity_id, "pdp_descriptor", return self._service(entity_id, "pdp_descriptor",
"authz_service", binding) "authz_service", binding)
def assertion_id_request_service(self, entity_id, typ, binding=None): def assertion_id_request_service(self, entity_id, binding=None, typ=None):
# AuthnAuthority + IDP + PDP + AttributeAuthority # AuthnAuthority + IDP + PDP + AttributeAuthority
if typ is None:
raise AttributeError("Missing type specification")
if binding is None: if binding is None:
binding = BINDING_SOAP binding = BINDING_SOAP
return self._service(entity_id, "%s_descriptor" % typ, return self._service(entity_id, "%s_descriptor" % typ,
"assertion_id_request_service", binding) "assertion_id_request_service", binding)
def single_logout_service(self, entity_id, typ, binding=None): def single_logout_service(self, entity_id, binding=None, typ=None):
# IDP + SP # IDP + SP
if typ is None:
raise AttributeError("Missing type specification")
if binding is None: if binding is None:
binding = BINDING_HTTP_REDIRECT binding = BINDING_HTTP_REDIRECT
return self._service(entity_id, "%s_descriptor" % typ, return self._service(entity_id, "%s_descriptor" % typ,
"single_logout_service", binding) "single_logout_service", binding)
def manage_name_id_service(self, entity_id, typ, binding=None): def manage_name_id_service(self, entity_id, binding=None, typ=None):
# IDP + SP # IDP + SP
if binding is None: if binding is None:
binding = BINDING_HTTP_REDIRECT binding = BINDING_HTTP_REDIRECT
return self._service(entity_id, "%s_descriptor" % typ, return self._service(entity_id, "%s_descriptor" % typ,
"manage_name_id_service", binding) "manage_name_id_service", binding)
def artifact_resolution_service(self, entity_id, typ, binding=None): def artifact_resolution_service(self, entity_id, binding=None, typ=None):
# IDP + SP # IDP + SP
if binding is None: if binding is None:
binding = BINDING_HTTP_REDIRECT binding = BINDING_HTTP_REDIRECT
return self._service(entity_id, "%s_descriptor" % typ, return self._service(entity_id, "%s_descriptor" % typ,
"artifact_resolution_service", binding) "artifact_resolution_service", binding)
def assertion_consumer_service(self, entity_id, binding=None): def assertion_consumer_service(self, entity_id, binding=None, typ="spsso"):
# SP # SP
if binding is None: if binding is None:
binding = BINDING_HTTP_POST binding = BINDING_HTTP_POST
return self._service(entity_id, "spsso_descriptor", return self._service(entity_id, "spsso_descriptor",
"assertion_consumer_service", binding) "assertion_consumer_service", binding)
def attribute_consuming_service(self, entity_id, binding=None): def attribute_consuming_service(self, entity_id, binding=None, typ="spsso"):
# SP # SP
if binding is None: if binding is None:
binding = BINDING_HTTP_REDIRECT binding = BINDING_HTTP_REDIRECT

View File

@@ -23,15 +23,14 @@ import logging
import shelve import shelve
import sys import sys
import memcache import memcache
from saml2.httpbase import HTTPBase from saml2.entity import Entity
from saml2.mdstore import destinations from saml2.samlp import LogoutResponse
from saml2 import saml, BINDING_HTTP_POST from saml2 import saml, VERSION
from saml2 import class_name from saml2 import class_name
from saml2 import soap from saml2 import soap
from saml2 import BINDING_HTTP_REDIRECT from saml2 import BINDING_HTTP_REDIRECT
from saml2 import BINDING_SOAP from saml2 import BINDING_SOAP
from saml2 import BINDING_PAOS
from saml2.request import AuthnRequest from saml2.request import AuthnRequest
from saml2.request import AttributeQuery from saml2.request import AttributeQuery
@@ -40,19 +39,13 @@ from saml2.request import LogoutRequest
from saml2.s_utils import sid from saml2.s_utils import sid
from saml2.s_utils import MissingValue from saml2.s_utils import MissingValue
from saml2.s_utils import success_status_factory from saml2.s_utils import success_status_factory
from saml2.s_utils import OtherError
from saml2.s_utils import UnknownPrincipal
from saml2.s_utils import UnsupportedBinding
from saml2.s_utils import error_status_factory from saml2.s_utils import error_status_factory
from saml2.time_util import instant from saml2.time_util import instant
from saml2.sigver import security_context
from saml2.sigver import signed_instance_factory from saml2.sigver import signed_instance_factory
from saml2.sigver import pre_signature_part from saml2.sigver import pre_signature_part
from saml2.sigver import response_factory, logoutresponse_factory from saml2.sigver import response_factory
from saml2.config import config_factory
from saml2.assertion import Assertion, Policy, restriction_from_attribute_spec, filter_attribute_value_assertions from saml2.assertion import Assertion, Policy, restriction_from_attribute_spec, filter_attribute_value_assertions
@@ -218,50 +211,25 @@ class Identifier(object):
except KeyError: except KeyError:
return None return None
class Server(HTTPBase): class Server(Entity):
""" A class that does things that IdPs or AAs do """ """ A class that does things that IdPs or AAs do """
def __init__(self, config_file="", config=None, _cache="", stype="idp"): def __init__(self, config_file="", config=None, _cache="", stype="idp"):
Entity.__init__(self, stype, config, config_file)
self.ident = None self.init_config(stype)
if config_file:
self.load_config(config_file, stype)
elif config:
self.conf = config
else:
raise Exception("Missing configuration")
HTTPBase.__init__(self, self.conf.verify_ssl_cert,
self.conf.ca_certs, self.conf.key_file,
self.conf.cert_file)
self.conf.setup_logger()
self.metadata = self.conf.metadata
self.sec = security_context(self.conf)
self._cache = _cache self._cache = _cache
# if cache: def init_config(self, stype="idp"):
# if isinstance(cache, basestring): """ Remaining init of the server configuration
# self.cache = Cache(cache)
# else:
# self.cache = cache
# else:
# self.cache = Cache()
def load_config(self, config_file, stype="idp"):
""" Load the server configuration
:param config_file: The name of the configuration file
:param stype: The type of Server ("idp"/"aa") :param stype: The type of Server ("idp"/"aa")
""" """
self.conf = config_factory(stype, config_file)
if stype == "aa": if stype == "aa":
return return
try: try:
# subject information is stored in a database # subject information is stored in a database
# default database is a shelve database which is OK in some setups # default database is a shelve database which is OK in some setups
dbspec = self.conf.getattr("subject_data", "idp") dbspec = self.config.getattr("subject_data", "idp")
idb = None idb = None
if isinstance(dbspec, basestring): if isinstance(dbspec, basestring):
idb = shelve.open(dbspec, writeback=True) idb = shelve.open(dbspec, writeback=True)
@@ -276,7 +244,7 @@ class Server(HTTPBase):
idb = addr idb = addr
if idb is not None: if idb is not None:
self.ident = Identifier(idb, self.conf.virtual_organization) self.ident = Identifier(idb, self.config.virtual_organization)
else: else:
raise Exception("Couldn't open identity database: %s" % raise Exception("Couldn't open identity database: %s" %
(dbspec,)) (dbspec,))
@@ -294,7 +262,7 @@ class Server(HTTPBase):
return saml.Issuer(text=entityid, return saml.Issuer(text=entityid,
format=saml.NAMEID_FORMAT_ENTITY) format=saml.NAMEID_FORMAT_ENTITY)
else: else:
return saml.Issuer(text=self.conf.entityid, return saml.Issuer(text=self.config.entityid,
format=saml.NAMEID_FORMAT_ENTITY) format=saml.NAMEID_FORMAT_ENTITY)
def parse_authn_request(self, enc_request, binding=BINDING_HTTP_REDIRECT): def parse_authn_request(self, enc_request, binding=BINDING_HTTP_REDIRECT):
@@ -315,81 +283,35 @@ class Server(HTTPBase):
_log_debug = logger.debug _log_debug = logger.debug
# The addresses I should receive messages like this on # The addresses I should receive messages like this on
receiver_addresses = self.conf.endpoint("single_sign_on_service", receiver_addresses = self.config.endpoint("single_sign_on_service",
binding) binding)
_log_info("receiver addresses: %s" % receiver_addresses) _log_info("receiver addresses: %s" % receiver_addresses)
_log_info("Binding: %s" % binding) _log_info("Binding: %s" % binding)
try: try:
timeslack = self.conf.accepted_time_diff timeslack = self.config.accepted_time_diff
if not timeslack: if not timeslack:
timeslack = 0 timeslack = 0
except AttributeError: except AttributeError:
timeslack = 0 timeslack = 0
authn_request = AuthnRequest(self.sec, authn_request = AuthnRequest(self.sec,
self.conf.attribute_converters, self.config.attribute_converters,
receiver_addresses, timeslack=timeslack) receiver_addresses, timeslack=timeslack)
if binding == BINDING_SOAP or binding == BINDING_PAOS:
# not base64 decoding and unzipping
authn_request.debug=True
authn_request = authn_request.loads(enc_request, binding)
else:
authn_request = authn_request.loads(enc_request, binding) authn_request = authn_request.loads(enc_request, binding)
_log_debug("Loaded authn_request") _log_debug("Loaded authn_request")
if authn_request: if authn_request:
authn_request = authn_request.verify() authn_request = authn_request.verify()
_log_debug("Verified authn_request") _log_debug("Verified authn_request")
if not authn_request: if not authn_request:
return None return None
response["id"] = authn_request.message.id # put in in_reply_to
sp_entity_id = authn_request.message.issuer.text
# try to find return address in metadata
# What's the binding ? ProtocolBinding
if authn_request.message.protocol_binding == BINDING_HTTP_REDIRECT:
_binding = BINDING_HTTP_POST
else: else:
_binding = authn_request.message.protocol_binding return authn_request
try:
srvs = self.metadata.assertion_consumer_service(sp_entity_id,
binding=_binding)
consumer_url = destinations(srvs)[0]
except (KeyError, IndexError):
_log_info("Failed to find consumer URL for %s" % sp_entity_id)
_log_info("Binding: %s" % _binding)
_log_info("entities: %s" % self.metadata.keys())
raise UnknownPrincipal(sp_entity_id)
if not consumer_url: # what to do ?
_log_info("Couldn't find a consumer URL binding=%s entity_id=%s" % (
_binding,sp_entity_id))
raise UnsupportedBinding(sp_entity_id)
response["sp_entity_id"] = sp_entity_id
response["binding"] = _binding
if authn_request.message.assertion_consumer_service_url:
return_destination = \
authn_request.message.assertion_consumer_service_url
if consumer_url != return_destination:
# serious error on someones behalf
_log_info("%s != %s" % (consumer_url, return_destination))
raise OtherError("ConsumerURL and return destination mismatch")
response["consumer_url"] = consumer_url
response["request"] = authn_request.message
return response
def wants(self, sp_entity_id, index=None): def wants(self, sp_entity_id, index=None):
""" Returns what attributes the SP requires and which are optional """ Returns what attributes the SP requires and which are optional
@@ -411,7 +333,7 @@ class Server(HTTPBase):
attribute - which attributes that the requestor wants back attribute - which attributes that the requestor wants back
query - the whole query query - the whole query
""" """
receiver_addresses = self.conf.endpoint("attribute_service") receiver_addresses = self.config.endpoint("attribute_service")
attribute_query = AttributeQuery( self.sec, receiver_addresses) attribute_query = AttributeQuery( self.sec, receiver_addresses)
attribute_query = attribute_query.loads(xml_string, binding) attribute_query = attribute_query.loads(xml_string, binding)
@@ -509,20 +431,20 @@ class Server(HTTPBase):
(authn_class, authn_authn) = authn (authn_class, authn_authn) = authn
assertion = ast.construct(sp_entity_id, in_response_to, assertion = ast.construct(sp_entity_id, in_response_to,
consumer_url, name_id, consumer_url, name_id,
self.conf.attribute_converters, self.config.attribute_converters,
policy, issuer=_issuer, policy, issuer=_issuer,
authn_class=authn_class, authn_class=authn_class,
authn_auth=authn_authn) authn_auth=authn_authn)
elif authn_decl: elif authn_decl:
assertion = ast.construct(sp_entity_id, in_response_to, assertion = ast.construct(sp_entity_id, in_response_to,
consumer_url, name_id, consumer_url, name_id,
self.conf.attribute_converters, self.config.attribute_converters,
policy, issuer=_issuer, policy, issuer=_issuer,
authn_decl=authn_decl) authn_decl=authn_decl)
else: else:
assertion = ast.construct(sp_entity_id, in_response_to, assertion = ast.construct(sp_entity_id, in_response_to,
consumer_url, name_id, consumer_url, name_id,
self.conf.attribute_converters, self.config.attribute_converters,
policy, issuer=_issuer) policy, issuer=_issuer)
if sign_assertion: if sign_assertion:
@@ -586,7 +508,7 @@ class Server(HTTPBase):
""" """
if not name_id and userid: if not name_id and userid:
try: try:
name_id = self.ident.construct_nameid(self.conf.policy, userid, name_id = self.ident.construct_nameid(self.config.policy, userid,
sp_entity_id, identity) sp_entity_id, identity)
logger.warning("Unspecified NameID format") logger.warning("Unspecified NameID format")
except Exception: except Exception:
@@ -597,7 +519,7 @@ class Server(HTTPBase):
if identity: if identity:
_issuer = self.issuer(issuer) _issuer = self.issuer(issuer)
ast = Assertion(identity) ast = Assertion(identity)
policy = self.conf.getattr("policy", "aa") policy = self.config.getattr("policy", "aa")
if policy: if policy:
ast.apply_policy(sp_entity_id, policy) ast.apply_policy(sp_entity_id, policy)
else: else:
@@ -609,7 +531,7 @@ class Server(HTTPBase):
assertion = ast.construct(sp_entity_id, in_response_to, assertion = ast.construct(sp_entity_id, in_response_to,
consumer_url, name_id, consumer_url, name_id,
self.conf.attribute_converters, self.config.attribute_converters,
policy, issuer=_issuer) policy, issuer=_issuer)
if sign_assertion: if sign_assertion:
@@ -648,7 +570,7 @@ class Server(HTTPBase):
:return: A response instance :return: A response instance
""" """
policy = self.conf.getattr("policy", "idp") policy = self.config.getattr("policy", "idp")
if not name_id: if not name_id:
try: try:
@@ -699,9 +621,9 @@ class Server(HTTPBase):
""" """
try: try:
slo = self.conf.endpoint("single_logout_service", binding, "idp") slo = self.config.endpoint("single_logout_service", binding, "idp")
except IndexError: except IndexError:
logger.info("enpoints: %s" % self.conf.getattr("endpoints", "idp")) logger.info("enpoints: %s" % self.config.getattr("endpoints", "idp"))
logger.info("binding wanted: %s" % (binding,)) logger.info("binding wanted: %s" % (binding,))
raise raise
@@ -733,48 +655,50 @@ class Server(HTTPBase):
return req return req
def create_logout_response(self, request, binding, status=None, def _status_response(self, response_class, issuer, status, sign=False,
sign=False, issuer=None): **kwargs):
""" Create a LogoutResponse. What is returned depends on which binding """ Create a StatusResponse.
is used.
:param request: The request this is a response to :param response_class: Which subclass of StatusResponse that should be
:param binding: Which binding the request came in over used
:param issuer: The issuer of the response message
:param status: The return status of the response operation :param status: The return status of the response operation
:param issuer: The issuer of the message :param sign: Whether the response should be signed or not
:return: A logout message. :param kwargs: Extra arguments to the response class
:return: Class instance or string representation of the instance
""" """
mid = sid() mid = sid()
if not status: if not status:
status = success_status_factory() status = success_status_factory()
# response and packaging differs depending on binding response = response_class(issuer=issuer, id=mid, version=VERSION,
response = "" issue_instant=instant(),
if binding in [BINDING_SOAP, BINDING_HTTP_POST]: status=status, **kwargs)
response = logoutresponse_factory(sign=sign, id = mid,
in_response_to = request.id,
status = status)
elif binding == BINDING_HTTP_REDIRECT:
sp_entity_id = request.issuer.text.strip()
srvs = self.metadata.single_logout_service(sp_entity_id, "spsso")
if not srvs:
raise Exception("Nowhere to send the response")
destination = destinations(srvs)[0]
_issuer = self.issuer(issuer)
response = logoutresponse_factory(sign=sign, id = mid,
in_response_to = request.id,
status = status,
issuer = _issuer,
destination = destination,
sp_entity_id = sp_entity_id,
instant=instant())
if sign: if sign:
response.signature = pre_signature_part(mid)
to_sign = [(class_name(response), mid)] to_sign = [(class_name(response), mid)]
response = signed_instance_factory(response, self.sec, to_sign) response = signed_instance_factory(response, self.sec, to_sign)
return response
def create_logout_response(self, request, bindings, status=None,
sign=False, issuer=None):
""" Create a LogoutResponse.
:param request: The request this is a response to
:param bindings: Which bindings that can be used for the response
:param status: The return status of the response operation
:param issuer: The issuer of the message
:return: HTTP args
"""
rinfo = self.response_args(request, bindings, descr_type="spsso")
response = self._status_response(LogoutResponse, issuer, status,
sign=False, **rinfo)
logger.info("Response: %s" % (response,)) logger.info("Response: %s" % (response,))
return response return response
@@ -788,7 +712,7 @@ class Server(HTTPBase):
attribute - which attributes that the requestor wants back attribute - which attributes that the requestor wants back
query - the whole query query - the whole query
""" """
receiver_addresses = self.conf.endpoint("attribute_service", "idp") receiver_addresses = self.config.endpoint("attribute_service", "idp")
attribute_query = AttributeQuery( self.sec, receiver_addresses) attribute_query = AttributeQuery( self.sec, receiver_addresses)
attribute_query = attribute_query.loads(xml_string, binding) attribute_query = attribute_query.loads(xml_string, binding)

View File

@@ -1054,19 +1054,19 @@ def pre_signature_part(ident, public_key=None, identifier=None):
return signature return signature
def logoutresponse_factory(sign=False, encrypt=False, **kwargs): #def logoutresponse_factory(sign=False, encrypt=False, **kwargs):
response = samlp.LogoutResponse(id=sid(), version=VERSION, # response = samlp.LogoutResponse(id=sid(), version=VERSION,
issue_instant=instant()) # issue_instant=instant())
#
if sign: # if sign:
response.signature = pre_signature_part(kwargs["id"]) # response.signature = pre_signature_part(kwargs["id"])
if encrypt: # if encrypt:
pass # pass
#
for key, val in kwargs.items(): # for key, val in kwargs.items():
setattr(response, key, val) # setattr(response, key, val)
#
return response # return response
def response_factory(sign=False, encrypt=False, **kwargs): def response_factory(sign=False, encrypt=False, **kwargs):
response = samlp.Response(id=sid(), version=VERSION, response = samlp.Response(id=sid(), version=VERSION,

View File

@@ -154,8 +154,7 @@ def test_ext_2():
ents = mds.with_descriptor("spsso") ents = mds.with_descriptor("spsso")
for binding in [BINDING_SOAP, BINDING_HTTP_POST, BINDING_HTTP_ARTIFACT, for binding in [BINDING_SOAP, BINDING_HTTP_POST, BINDING_HTTP_ARTIFACT,
BINDING_HTTP_REDIRECT]: BINDING_HTTP_REDIRECT]:
assert mds.single_logout_service(ents.keys()[0], "spsso", assert mds.single_logout_service(ents.keys()[0], binding, "spsso")
binding=binding)
def test_example(): def test_example():
mds = MetadataStore(ONTS.values(), ATTRCONV, xmlsec_path, mds = MetadataStore(ONTS.values(), ATTRCONV, xmlsec_path,

View File

@@ -76,7 +76,7 @@ class TestServer1():
assert isinstance(issuer, saml.Issuer) assert isinstance(issuer, saml.Issuer)
assert _eq(issuer.keyswv(), ["text","format"]) assert _eq(issuer.keyswv(), ["text","format"])
assert issuer.format == saml.NAMEID_FORMAT_ENTITY assert issuer.format == saml.NAMEID_FORMAT_ENTITY
assert issuer.text == self.server.conf.entityid assert issuer.text == self.server.config.entityid
def test_assertion(self): def test_assertion(self):
@@ -184,15 +184,17 @@ class TestServer1():
print authn_request print authn_request
intermed = s_utils.deflate_and_base64_encode("%s" % authn_request) intermed = s_utils.deflate_and_base64_encode("%s" % authn_request)
response = self.server.parse_authn_request(intermed) req = self.server.parse_authn_request(intermed)
# returns a dictionary # returns a dictionary
print response print req
assert response["consumer_url"] == "http://lingon.catalogix.se:8087/" resp_args = self.server.response_args(req.message, [BINDING_HTTP_POST],
assert response["id"] == "id1" descr_type="spsso")
name_id_policy = response["request"].name_id_policy assert resp_args["destination"] == "http://lingon.catalogix.se:8087/"
assert resp_args["in_response_to"] == "id1"
name_id_policy = resp_args["name_id_policy"]
assert _eq(name_id_policy.keyswv(), ["format", "allow_create"]) assert _eq(name_id_policy.keyswv(), ["format", "allow_create"])
assert name_id_policy.format == saml.NAMEID_FORMAT_TRANSIENT assert name_id_policy.format == saml.NAMEID_FORMAT_TRANSIENT
assert response["sp_entity_id"] == "urn:mace:example.com:saml:roland:sp" assert resp_args["sp_entity_id"] == "urn:mace:example.com:saml:roland:sp"
def test_sso_response_with_identity(self): def test_sso_response_with_identity(self):
name_id = self.server.ident.transient_nameid( name_id = self.server.ident.transient_nameid(
@@ -423,7 +425,7 @@ class TestServer2():
self.server.close_shelve_db() self.server.close_shelve_db()
def test_do_aa_reponse(self): def test_do_aa_reponse(self):
aa_policy = self.server.conf.getattr("policy", "idp") aa_policy = self.server.config.getattr("policy", "idp")
print aa_policy.__dict__ print aa_policy.__dict__
response = self.server.create_aa_response("aaa", response = self.server.create_aa_response("aaa",
"http://example.com/sp/", "http://example.com/sp/",
@@ -474,10 +476,15 @@ class TestServerLogout():
server = Server("idp_slo_redirect_conf") server = Server("idp_slo_redirect_conf")
request = _logout_request("sp_slo_redirect_conf") request = _logout_request("sp_slo_redirect_conf")
print request print request
binding = BINDING_HTTP_REDIRECT bindings = [BINDING_HTTP_REDIRECT]
response = server.create_logout_response(request, binding) response = server.create_logout_response(request, bindings)
http_args = server.use_http_get(response, response.destination, binding, destination = server.pick_binding(bindings,
"/relay_state") "single_logout_service",
assert len(http_args) == 2 "spsso", request)
http_args = server.apply_binding(binding, "%s" % response, destination,
"relay_state", "SAMLResponse")
assert len(http_args) == 4
assert http_args["headers"][0][0] == "Location" assert http_args["headers"][0][0] == "Location"
assert http_args["data"] == [''] assert http_args["data"] == ['']