add_header method and returning a byte string.

Fixed some PY3 problems
service_per_endpoint method useful when setting up an entity instance.
This commit is contained in:
Roland Hedberg 2016-04-11 16:17:20 +02:00
parent 2ba6dd7cba
commit 61afe88cd9
5 changed files with 87 additions and 62 deletions

View File

@ -65,9 +65,6 @@ def dict_to_table(ava, lev=0, width=1):
txt.append("<tr>\n") txt.append("<tr>\n")
if isinstance(valarr, six.string_types): if isinstance(valarr, six.string_types):
txt.append("<th>%s</th>\n" % str(prop)) txt.append("<th>%s</th>\n" % str(prop))
try:
txt.append("<td>%s</td>\n" % valarr.encode("utf8"))
except AttributeError:
txt.append("<td>%s</td>\n" % valarr) txt.append("<td>%s</td>\n" % valarr)
elif isinstance(valarr, list): elif isinstance(valarr, list):
i = 0 i = 0
@ -82,9 +79,6 @@ def dict_to_table(ava, lev=0, width=1):
txt.extend(dict_to_table(val, lev + 1, width - 1)) txt.extend(dict_to_table(val, lev + 1, width - 1))
txt.append("</td>\n") txt.append("</td>\n")
else: else:
try:
txt.append("<td>%s</td>\n" % val.encode("utf8"))
except AttributeError:
txt.append("<td>%s</td>\n" % val) txt.append("<td>%s</td>\n" % val)
if n > 1: if n > 1:
txt.append("</tr>\n") txt.append("</tr>\n")
@ -206,7 +200,7 @@ class Cache(object):
cookie[self.cookie_name]['path'] = "/" cookie[self.cookie_name]['path'] = "/"
cookie[self.cookie_name]["expires"] = _expiration(480) cookie[self.cookie_name]["expires"] = _expiration(480)
logger.debug("Cookie expires: %s", cookie[self.cookie_name]["expires"]) logger.debug("Cookie expires: %s", cookie[self.cookie_name]["expires"])
return cookie.output().encode("UTF-8").split(": ", 1) return cookie.output().split(": ", 1)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -230,12 +224,9 @@ class Service(object):
return None return None
def unpack_post(self): def unpack_post(self):
_dict = parse_qs(get_post(self.environ)) _dict = parse_qs(get_post(self.environ).decode('utf8'))
logger.debug("unpack_post:: %s", _dict) logger.debug("unpack_post:: %s", _dict)
try:
return dict([(k, v[0]) for k, v in _dict.items()]) return dict([(k, v[0]) for k, v in _dict.items()])
except Exception:
return None
def unpack_soap(self): def unpack_soap(self):
try: try:
@ -544,7 +535,7 @@ class SSO(object):
return -1, SeeOther(loc) return -1, SeeOther(loc)
elif len(idps) == 1: elif len(idps) == 1:
# idps is a dictionary # idps is a dictionary
idp_entity_id = idps.keys()[0] idp_entity_id = list(idps.keys())[0]
elif not len(idps): elif not len(idps):
return -1, ServiceError('Misconfiguration') return -1, ServiceError('Misconfiguration')
else: else:

View File

@ -1,8 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python
from saml2.saml import NAME_FORMAT_URI
__author__ = 'rolandh'
import copy import copy
import sys import sys
import os import os
@ -11,7 +7,7 @@ import logging
import logging.handlers import logging.handlers
import six import six
from importlib import import_module from future.backports.test.support import import_module
from saml2 import root_logger, BINDING_URI, SAMLError from saml2 import root_logger, BINDING_URI, SAMLError
from saml2 import BINDING_SOAP from saml2 import BINDING_SOAP
@ -22,10 +18,13 @@ from saml2 import BINDING_HTTP_ARTIFACT
from saml2.attribute_converter import ac_factory from saml2.attribute_converter import ac_factory
from saml2.assertion import Policy from saml2.assertion import Policy
from saml2.mdstore import MetadataStore from saml2.mdstore import MetadataStore
from saml2.saml import NAME_FORMAT_URI
from saml2.virtual_org import VirtualOrg from saml2.virtual_org import VirtualOrg
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__author__ = 'rolandh'
COMMON_ARGS = [ COMMON_ARGS = [
"entityid", "xmlsec_binary", "debug", "key_file", "cert_file", "entityid", "xmlsec_binary", "debug", "key_file", "cert_file",
@ -493,6 +492,22 @@ class Config(object):
for key, val in extensions.items(): for key, val in extensions.items():
self.extensions[key] = val self.extensions[key] = val
def service_per_endpoint(self, context=None):
"""
List all endpoint this entity publishes and which service and binding
that are behind the endpoint
:param context: Type of entity
:return: Dictionary with endpoint url as key and a tuple of
service and binding as value
"""
endps = self.getattr("endpoints", context)
res = {}
for service, specs in endps.items():
for endp, binding in specs:
res[endp] = (service, binding)
return res
class SPConfig(Config): class SPConfig(Config):
def_context = "sp" def_context = "sp"

View File

@ -64,9 +64,21 @@ class Response(object):
else: else:
if isinstance(message, six.string_types): if isinstance(message, six.string_types):
return [message] return [message]
elif isinstance(message, six.binary_type):
return [message]
else: else:
return message return message
def add_header(self, ava):
"""
Does *NOT* replace a header of the same type, just adds a new
:param ava: (type, value) tuple
"""
self.headers.append(ava)
def reply(self, **kwargs):
return self.response(self.message, **kwargs)
class Created(Response): class Created(Response):
_status = "201 Created" _status = "201 Created"

View File

@ -6,10 +6,13 @@ import sys
import json import json
import requests import requests
import six import six
from hashlib import sha1 from hashlib import sha1
from os.path import isfile from os.path import isfile
from os.path import join from os.path import join
from future.backports.test.support import import_module
from saml2 import md from saml2 import md
from saml2 import saml from saml2 import saml
from saml2 import samlp from saml2 import samlp
@ -32,7 +35,6 @@ from saml2.validate import valid_instance
from saml2.time_util import valid from saml2.time_util import valid
from saml2.validate import NotValid from saml2.validate import NotValid
from saml2.sigver import security_context from saml2.sigver import security_context
from importlib import import_module
__author__ = 'rolandh' __author__ = 'rolandh'
@ -225,7 +227,7 @@ class MetaData(object):
''' '''
raise NotImplementedError raise NotImplementedError
def load(self): def load(self, *args, **kwargs):
''' '''
Loads the metadata Loads the metadata
''' '''
@ -634,7 +636,7 @@ class MetaDataFile(InMemoryMetaData):
def get_metadata_content(self): def get_metadata_content(self):
return open(self.filename, 'rb').read() return open(self.filename, 'rb').read()
def load(self): def load(self, *args, **kwargs):
_txt = self.get_metadata_content() _txt = self.get_metadata_content()
return self.parse_and_check_signature(_txt) return self.parse_and_check_signature(_txt)
@ -655,7 +657,7 @@ class MetaDataLoader(MetaDataFile):
@staticmethod @staticmethod
def get_metadata_loader(func): def get_metadata_loader(func):
if callable(func): if hasattr(func, '__call__'):
return func return func
i = func.rfind('.') i = func.rfind('.')
@ -673,7 +675,7 @@ class MetaDataLoader(MetaDataFile):
'Module "%s" does not define a "%s" metadata loader' % ( 'Module "%s" does not define a "%s" metadata loader' % (
module, attr)) module, attr))
if not callable(metadata_loader): if not hasattr(metadata_loader, '__call__'):
raise RuntimeError( raise RuntimeError(
'Metadata loader %s.%s must be callable' % (module, attr)) 'Metadata loader %s.%s must be callable' % (module, attr))
@ -710,7 +712,7 @@ class MetaDataExtern(InMemoryMetaData):
self.security = security self.security = security
self.http = http self.http = http
def load(self): def load(self, *args, **kwargs):
""" Imports metadata by the use of HTTP GET. """ Imports metadata by the use of HTTP GET.
If the fingerprint is known the file will be checked for If the fingerprint is known the file will be checked for
compliance before it is imported. compliance before it is imported.
@ -734,7 +736,7 @@ class MetaDataMD(InMemoryMetaData):
super(MetaDataMD, self).__init__(attrc, **kwargs) super(MetaDataMD, self).__init__(attrc, **kwargs)
self.filename = filename self.filename = filename
def load(self): def load(self, *args, **kwargs):
for key, item in json.loads(open(self.filename).read()): for key, item in json.loads(open(self.filename).read()):
self.entity[key] = item self.entity[key] = item
@ -760,7 +762,7 @@ class MetaDataMDX(InMemoryMetaData):
concatenated with the request URL sent to the MDX server. Defaults to concatenated with the request URL sent to the MDX server. Defaults to
sha1 transformation. sha1 transformation.
""" """
super(MetaDataMDX, self).__init__(None, None) super(MetaDataMDX, self).__init__(None, '')
self.url = url self.url = url
if entity_transform: if entity_transform:
@ -769,7 +771,7 @@ class MetaDataMDX(InMemoryMetaData):
self.entity_transform = MetaDataMDX.sha1_entity_transform self.entity_transform = MetaDataMDX.sha1_entity_transform
def load(self): def load(self, *args, **kwargs):
# Do nothing # Do nothing
pass pass
@ -807,7 +809,7 @@ class MetadataStore(MetaData):
:params ca_certs: :params ca_certs:
:params disable_ssl_certificate_validation: :params disable_ssl_certificate_validation:
""" """
self.attrc = attrc MetaData.__init__(self, attrc, check_validity=check_validity)
if disable_ssl_certificate_validation: if disable_ssl_certificate_validation:
self.http = HTTPBase(verify=False, ca_bundle=ca_certs) self.http = HTTPBase(verify=False, ca_bundle=ca_certs)
@ -821,12 +823,13 @@ class MetadataStore(MetaData):
self.filter = filter self.filter = filter
self.to_old = {} self.to_old = {}
def load(self, typ, *args, **kwargs): def load(self, *args, **kwargs):
if self.filter: if self.filter:
_args = {"filter": self.filter} _args = {"filter": self.filter}
else: else:
_args = {} _args = {}
typ = args[0]
if typ == "local": if typ == "local":
key = args[0] key = args[0]
# if library read every file in the library # if library read every file in the library

View File

@ -5,9 +5,11 @@ from saml2 import attribute_converter, saml
from attribute_statement_data import * from attribute_statement_data import *
from pathutils import full_path from pathutils import full_path
from saml2.attribute_converter import AttributeConverterNOOP, AttributeConverter from saml2.attribute_converter import AttributeConverterNOOP
from saml2.attribute_converter import AttributeConverter
from saml2.attribute_converter import to_local from saml2.attribute_converter import to_local
from saml2.saml import attribute_from_string from saml2.saml import attribute_from_string
from saml2.saml import attribute_statement_from_string
def _eq(l1, l2): def _eq(l1, l2):
@ -134,22 +136,20 @@ class TestAC():
assert _eq(lan, ['sn', 'givenName', 'title']) assert _eq(lan, ['sn', 'givenName', 'title'])
# def test_ava_fro_1(self): def test_to_local_name_from_unspecified(self):
# _xml = """<?xml version='1.0' encoding='UTF-8'?>
# attr = [saml.Attribute(friendly_name="surName", <ns0:AttributeStatement xmlns:ns0="urn:oasis:names:tc:SAML:2.0:assertion">
# name="urn:oid:2.5.4.4", <ns0:Attribute
# name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
# saml.Attribute(friendly_name="efternamn", Name="EmailAddress"
# name="urn:oid:2.5.4.42", NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified">
# name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"), <ns0:AttributeValue xsi:type="xs:string">foo@bar.com</ns0:AttributeValue>
# saml.Attribute(friendly_name="titel", </ns0:Attribute></ns0:AttributeStatement>"""
# name="urn:oid:2.5.4.12",
# name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")] attr = attribute_statement_from_string(_xml)
# ava = attribute_converter.to_local(self.acs, attr)
# result = attribute_converter.ava_fro(self.acs, attr)
# assert _eq(list(ava.keys()), ['EmailAddress'])
# print(result)
# assert result == {'givenName': [], 'sn': [], 'title': []}
def test_to_local_name_from_basic(self): def test_to_local_name_from_basic(self):
attr = [ attr = [
@ -229,7 +229,12 @@ def test_noop_attribute_conversion():
ava = """<?xml version='1.0' encoding='UTF-8'?> ava = """<?xml version='1.0' encoding='UTF-8'?>
<ns0:Attribute xmlns:ns0="urn:oasis:names:tc:SAML:2.0:assertion" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" FriendlyName="schacHomeOrganization" Name="urn:oid:1.3.6.1.4.1.25178.1.2.9" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"><ns0:AttributeValue xsi:nil="true" xsi:type="xs:string">uu.se</ns0:AttributeValue></ns0:Attribute>""" <ns0:Attribute xmlns:ns0="urn:oasis:names:tc:SAML:2.0:assertion"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
FriendlyName="schacHomeOrganization" Name="urn:oid:1.3.6.1.4.1.25178.1.2.9"
NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"><ns0
:AttributeValue xsi:nil="true"
xsi:type="xs:string">uu.se</ns0:AttributeValue></ns0:Attribute>"""
def test_schac(): def test_schac():
@ -246,5 +251,4 @@ def test_schac():
if __name__ == "__main__": if __name__ == "__main__":
t = TestAC() t = TestAC()
t.setup_class() t.setup_class()
t.test_to_attrstat_1() t.test_to_local_name_from_unspecified()
# test_schac()