Effects of refactoring.

This commit is contained in:
Roland Hedberg
2010-07-21 13:13:12 +02:00
parent 046204bf6b
commit 73b55e082f
5 changed files with 566 additions and 0 deletions

153
src/saml2/binding.py Normal file
View File

@@ -0,0 +1,153 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Umeå University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains classes and functions that are necessary to implement
different bindings.
Bindings normally consists of three parts:
- rules about what to send
- how to package the information
- which protocol to use
"""
import httplib2
try:
from xml.etree import cElementTree as ElementTree
except ImportError:
try:
import cElementTree as ElementTree
except ImportError:
from elementtree import ElementTree
from saml2.samlp import NAMESPACE as SAMLP_NAMESPACE
import saml2
NAMESPACE = "http://schemas.xmlsoap.org/soap/envelope/"
def http_post(authn_request, sp_entity_id=None, relay_state=None):
response = []
response.append("<head>")
response.append("""<title>SAML 2.0 POST</title>""")
response.append("</head><body>")
#login_url = location + '?spentityid=' + "lingon.catalogix.se"
response.append(FORM_SPEC % (location, base64.b64encode(authen_req),
os.environ['REQUEST_URI']))
response.append("""<script type="text/javascript">""")
response.append(" window.onload = function ()")
response.append(" { document.forms[0].submit(); ")
response.append("""</script>""")
response.append("</body>")
return ([], response)
def http_redirect(authn_request, sp_entity_id, relay_state):
lista = ["SAMLRequest=%s" % urllib.quote_plus(
deflate_and_base64_encode(
authen_req)),
"spentityid=%s" % sp_entity_id]
if relay_state:
lista.append("RelayState=%s" % relay_state)
login_url = "?".join([location, "&".join(lista)])
headers = [('Location', login_url)]
response = []
return (headers, response)
def make_soap_enveloped_saml_thingy(thingy, header_parts=None):
""" Returns a soap envelope containing a SAML request
as a text string.
:param thingy: The SAML thingy
:return: The SOAP envelope as a string
"""
envelope = ElementTree.Element('')
envelope.tag = '{%s}Envelope' % NAMESPACE
if header_parts:
header = ElementTree.Element('')
header.tag = '{%s}Header' % NAMESPACE
envelope.append(header)
for part in header_parts:
part.become_child_element_of(header)
body = ElementTree.Element('')
body.tag = '{%s}Body' % NAMESPACE
envelope.append(body)
thingy.become_child_element_of(body)
return ElementTree.tostring(envelope, encoding="UTF-8")
def http_soap(authn_request, sp_entity_id, relay_state):
return ({"content-type": "application/soap+xml"},
make_soap_enveloped_saml_thingy(authn_request))
def http_paos(authn_request, sp_entity_id, relay_state, extra=None):
return ({"content-type": "application/soap+xml"},
make_soap_enveloped_saml_thingy(authn_request, extra))
def parse_soap_enveloped_saml(text, body_class, header_class=None):
"""Parses a SOAP enveloped SAML thing and returns header parts and body
:param text: The SOAP object as XML
:return: header parts and body as saml.samlbase instances
"""
envelope = ElementTree.fromstring(text)
assert envelope.tag == '{%s}Envelope' % NAMESPACE
print len(envelope)
body = None
header = {}
for part in envelope:
print ">",part.tag
if part.tag == '{%s}Body' % NAMESPACE:
for sub in part:
try:
body = saml2.create_class_from_element_tree(body_class, sub)
except Exception, exc:
print exc
print body_class.c_tag
raise Exception(
"Wrong body type (%s) in SOAP envelope" % sub.tag)
elif part.tag == '{%s}Header' % NAMESPACE:
if not header_class:
raise Exception("Header where I didn't expect one")
print "--- HEADER ---"
for sub in part:
print ">>",sub.tag
for klass in header_class:
print "?{%s}%s" % (klass.c_namespace,klass.c_tag)
if sub.tag == "{%s}%s" % (klass.c_namespace,klass.c_tag):
header[sub.tag] = \
saml2.create_class_from_element_tree(klass, sub)
break
return body, header
# -----------------------------------------------------------------------------
PACKING = {
saml2.BINDING_HTTP_REDIRECT: http_redirect,
saml2.BINDING_HTTP_POST: http_post,
}
def packager( identifier ):
try:
return PACKING[identifier]
except KeyError:
raise Exception("Unkown binding type: %s" % binding)

32
src/saml2/idpdisc.py Normal file
View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python
#
# Generated Mon Jul 12 22:05:35 2010 by parse_xsd.py version 0.2.
#
import saml2
from saml2 import SamlBase
from saml2 import md
NAMESPACE = "urn:oasis:names:tc:SAML:profiles:SSO:idp-discovery-protocol"
class DiscoveryResponse(md.IndexedEndpointType):
"""The idpdisc:DiscoveryResponse element"""
c_tag = 'DiscoveryResponse'
c_namespace = NAMESPACE
def discovery_response_from_string(xml_string):
""" Create DiscoveryResponse instance from an XML string """
return saml2.create_class_from_xml_string(DiscoveryResponse, xml_string)
ELEMENT_FROM_STRING = {
DiscoveryResponse.c_tag: discovery_response_from_string,
}
ELEMENT_BY_TAG = {
'DiscoveryResponse': DiscoveryResponse,
}
def factory(tag, **kwargs):
return ELEMENT_BY_TAG[tag](**kwargs)

50
src/saml2/population.py Normal file
View File

@@ -0,0 +1,50 @@
from saml2.cache import Cache
class Population(object):
def __init__(self, cache=None):
if cache:
self.cache = Cache(cache)
else:
self.cache = Cache()
def add_information_about_person(self, session_info):
"""If there already are information from this source in the cache
this function will overwrite that information"""
name_id = session_info["name_id"]
issuer = session_info["issuer"]
del session_info["issuer"]
self.cache.set(name_id, issuer, session_info,
session_info["not_on_or_after"])
return name_id
def stale_sources_for_person(self, subject_id, sources=None):
if not sources: # assume that all the members has be asked
# once before, hence they are represented in the cache
sources = self.cache.entities(subject_id)
sources = [m for m in sources \
if not self.cache.active(subject_id, m)]
return sources
def issuers_of_info(self, subject_id):
return self.cache.entities(subject_id)
def get_identity(self, subject_id):
return self.cache.get_identity(subject_id)
def get_info_from(self, subject_id, entity_id):
return self.cache.get(subject_id, entity_id)
def subjects(self):
"""Returns the name id's for all the persons in the cache"""
return self.cache.subjects();
def remove_person(self, subject_id):
self.cache.delete(subject_id)
def get_entityid(self, subject_id, source_id):
try:
return self.cache.get(subject_id, source_id)["name_id"]
except (KeyError, ValueError):
return ""

271
src/saml2/s_utils.py Normal file
View File

@@ -0,0 +1,271 @@
#!/usr/bin/env python
import time
import base64
from saml2 import saml, samlp, VERSION, sigver
from saml2.time_util import instant
try:
from hashlib import md5
except ImportError:
from md5 import md5
import zlib
class VersionMismatch(Exception):
pass
class UnknownPrincipal(Exception):
pass
class UnsupportedBinding(Exception):
pass
class OtherError(Exception):
pass
class MissingValue(Exception):
pass
EXCEPTION2STATUS = {
VersionMismatch: samlp.STATUS_VERSION_MISMATCH,
UnknownPrincipal: samlp.STATUS_UNKNOWN_PRINCIPAL,
UnsupportedBinding: samlp.STATUS_UNSUPPORTED_BINDING,
OtherError: samlp.STATUS_UNKNOWN_PRINCIPAL,
MissingValue: samlp.STATUS_REQUEST_UNSUPPORTED,
}
GENERIC_DOMAINS = "aero", "asia", "biz", "cat", "com", "coop", \
"edu", "gov", "info", "int", "jobs", "mil", "mobi", "museum", \
"name", "net", "org", "pro", "tel", "travel"
def valid_email(emailaddress, domains = GENERIC_DOMAINS):
"""Checks for a syntactically valid email address."""
# Email address must be at least 6 characters in total.
# Assuming noone may have addresses of the type a@com
if len(emailaddress) < 6:
return False # Address too short.
# Split up email address into parts.
try:
localpart, domainname = emailaddress.rsplit('@', 1)
host, toplevel = domainname.rsplit('.', 1)
except ValueError:
return False # Address does not have enough parts.
# Check for Country code or Generic Domain.
if len(toplevel) != 2 and toplevel not in domains:
return False # Not a domain name.
for i in '-_.%+.':
localpart = localpart.replace(i, "")
for i in '-_.':
host = host.replace(i, "")
if localpart.isalnum() and host.isalnum():
return True # Email address is fine.
else:
return False # Email address has funny characters.
def decode_base64_and_inflate( string ):
""" base64 decodes and then inflates according to RFC1951
:param string: a deflated and encoded string
:return: the string after decoding and inflating
"""
return zlib.decompress( base64.b64decode( string ) , -15)
def deflate_and_base64_encode( string_val ):
"""
Deflates and the base64 encodes a string
:param string_val: The string to deflate and encode
:return: The deflated and encoded string
"""
return base64.b64encode( zlib.compress( string_val )[2:-4] )
def sid(seed=""):
"""The hash of the server time + seed makes an unique SID for each session.
:param seed: A seed string
:return: The hex version of the digest
"""
ident = md5()
ident.update(repr(time.time()))
if seed:
ident.update(seed)
return ident.hexdigest()
def parse_attribute_map(filenames):
"""
Expects a file with each line being composed of the oid for the attribute
exactly one space, a user friendly name of the attribute and then
the type specification of the name.
:param filename: List of filenames on mapfiles.
:return: A 2-tuple, one dictionary with the oid as keys and the friendly
names as values, the other one the other way around.
"""
forward = {}
backward = {}
for filename in filenames:
for line in open(filename).readlines():
(name, friendly_name, name_format) = line.strip().split()
forward[(name, name_format)] = friendly_name
backward[friendly_name] = (name, name_format)
return (forward, backward)
def identity_attribute(form, attribute, forward_map=None):
if form == "friendly":
if attribute.friendly_name:
return attribute.friendly_name
elif forward_map:
try:
return forward_map[(attribute.name, attribute.name_format)]
except KeyError:
return attribute.name
# default is name
return attribute.name
#----------------------------------------------------------------------------
def status_from_exception_factory(exception):
msg = exception.args[0]
status = samlp.Status(
status_message=samlp.StatusMessage(text=msg),
status_code=samlp.StatusCode(
value=samlp.STATUS_RESPONDER,
status_code=samlp.StatusCode(
value=EXCEPTION2STATUS[exception.__class__])
),
)
return status
def success_status_factory():
return samlp.Status(status_code=samlp.StatusCode(
value=samlp.STATUS_SUCCESS))
def status_message_factory(message, code, fro=samlp.STATUS_RESPONDER):
return samlp.Status(
status_message=samlp.StatusMessage(text=message),
status_code=samlp.StatusCode(
value=fro,
status_code=samlp.StatusCode(value=code)))
def assertion_factory(**kwargs):
assertion = saml.Assertion(version=VERSION, id=sid(),
issue_instant=instant())
for key, val in kwargs.items():
setattr(assertion, key, val)
return assertion
def response_factory(signature=False, encrypt=False, **kwargs):
response = samlp.Response(id=sid(), version=VERSION,
issue_instant=instant())
if signature:
response["signature"] = sigver.pre_signature_part(kwargs["id"])
if encrypt:
pass
for key, val in kwargs.items():
setattr(response, key, val)
return response
def _attrval(val, typ=""):
if isinstance(val, list) or isinstance(val, set):
attrval = [saml.AttributeValue(text=v) for v in val]
elif val == None:
attrval = None
else:
attrval = [saml.AttributeValue(text=val)]
if typ:
for ava in attrval:
ava.set_type(typ)
return attrval
# --- attribute profiles -----
# xmlns:xs="http://www.w3.org/2001/XMLSchema"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
def do_ava(val, typ=""):
if isinstance(val, basestring):
ava = saml.AttributeValue()
ava.set_text(val)
attrval = [ava]
elif isinstance(val, list):
attrval = [do_ava(v)[0] for v in val]
elif val or val == False:
ava = saml.AttributeValue()
ava.set_text(val)
attrval = [ava]
elif val == None:
attrval = None
else:
raise OtherError("strange value type on: %s" % val)
if typ:
for ava in attrval:
ava.set_type(typ)
return attrval
def do_attribute(val, typ, key):
attr = saml.Attribute()
attrval = do_ava(val, typ)
if attrval:
attr.attribute_value = attrval
if isinstance(key, basestring):
attr.name = key
elif isinstance(key, tuple): # 3-tuple or 2-tuple
try:
(name, nformat, friendly) = key
except ValueError:
(name, nformat) = key
friendly = ""
if name:
attr.name = name
if format:
attr.name_format = nformat
if friendly:
attr.friendly_name = friendly
return attr
def do_attributes(identity):
attrs = []
if not identity:
return attrs
for key, spec in identity.items():
try:
val, typ = spec
except ValueError:
val = spec
typ = ""
except TypeError:
val = ""
typ = ""
attr = do_attribute(val, typ, key)
attrs.append(attr)
return attrs
def do_attribute_statement(identity):
"""
:param identity: A dictionary with fiendly names as keys
:return:
"""
return saml.AttributeStatement(attribute=do_attributes(identity))
def factory(klass, **kwargs):
instance = klass()
for key, val in kwargs.items():
setattr(instance, key, val)
return instance

60
src/saml2/virtual_org.py Normal file
View File

@@ -0,0 +1,60 @@
from saml2.attribute_resolver import AttributeResolver
class VirtualOrg(object):
def __init__(self, metadata, vo_org, population, log=None, vorg_conf=None):
self.metadata = metadata
self.log = log
self.vorg_conf = vorg_conf
self.vorg = vo_org
self.population = population
def members_to_ask(self, subject_id):
# Find the member of the Virtual Organization that I haven't
# alrady spoken too
vo_members = [
member for member in self.metadata.vo_members(self.vorg)\
if member not in self.srv["idp"].keys()]
self.log and self.log.info("VO members: %s" % vo_members)
# Remove the ones I have cached data from about this subject
vo_members = [m for m in vo_members \
if not self.cache.active(subject_id, m)]
self.log and self.log.info(
"VO members (not cached): %s" % vo_members)
return vo_members
def do_aggregation(self, subject_id):
if self.log:
self.log.info("** Do VO aggregation **")
self.log.info("SubjectID: %s, VO:%s" % (subject_id, self.vorg))
vo_members = self.members_to_ask(subject_id)
if vo_members:
# Find the NameIDFormat and the SPNameQualifier
if self.vorg_conf and "name_id_format" in self.vorg_conf:
name_id_format = self.vorg_conf["name_id_format"]
sp_name_qualifier = ""
else:
sp_name_qualifier = self.vorg
name_id_format = ""
resolver = AttributeResolver(environ, self.metadata, self.conf)
# extends returns a list of session_infos
for session_info in resolver.extend(subject_id,
self.conf["entityid"], vo_members,
name_id_format=name_id_format,
sp_name_qualifier=sp_name_qualifier,
log=self.log):
_ignore = self._cache_session(session_info)
if self.log:
self.log.info(
">Issuers: %s" % self.population.issuers_of_info(subject_id))
self.log.info(
"AVA: %s" % (self.population.get_identity(subject_id),))
return True
else:
return False