add_header method and returning a byte string.

Fixed some PY3 problems
service_per_endpoint method useful when setting up an entity instance.
This commit is contained in:
Roland Hedberg 2016-04-11 16:17:20 +02:00
parent 2ba6dd7cba
commit 61afe88cd9
5 changed files with 87 additions and 62 deletions

View File

@ -65,9 +65,6 @@ def dict_to_table(ava, lev=0, width=1):
txt.append("<tr>\n")
if isinstance(valarr, six.string_types):
txt.append("<th>%s</th>\n" % str(prop))
try:
txt.append("<td>%s</td>\n" % valarr.encode("utf8"))
except AttributeError:
txt.append("<td>%s</td>\n" % valarr)
elif isinstance(valarr, list):
i = 0
@ -82,9 +79,6 @@ def dict_to_table(ava, lev=0, width=1):
txt.extend(dict_to_table(val, lev + 1, width - 1))
txt.append("</td>\n")
else:
try:
txt.append("<td>%s</td>\n" % val.encode("utf8"))
except AttributeError:
txt.append("<td>%s</td>\n" % val)
if n > 1:
txt.append("</tr>\n")
@ -206,7 +200,7 @@ class Cache(object):
cookie[self.cookie_name]['path'] = "/"
cookie[self.cookie_name]["expires"] = _expiration(480)
logger.debug("Cookie expires: %s", cookie[self.cookie_name]["expires"])
return cookie.output().encode("UTF-8").split(": ", 1)
return cookie.output().split(": ", 1)
# -----------------------------------------------------------------------------
@ -230,12 +224,9 @@ class Service(object):
return None
def unpack_post(self):
_dict = parse_qs(get_post(self.environ))
_dict = parse_qs(get_post(self.environ).decode('utf8'))
logger.debug("unpack_post:: %s", _dict)
try:
return dict([(k, v[0]) for k, v in _dict.items()])
except Exception:
return None
def unpack_soap(self):
try:
@ -544,7 +535,7 @@ class SSO(object):
return -1, SeeOther(loc)
elif len(idps) == 1:
# idps is a dictionary
idp_entity_id = idps.keys()[0]
idp_entity_id = list(idps.keys())[0]
elif not len(idps):
return -1, ServiceError('Misconfiguration')
else:

View File

@ -1,8 +1,4 @@
#!/usr/bin/env python
from saml2.saml import NAME_FORMAT_URI
__author__ = 'rolandh'
import copy
import sys
import os
@ -11,7 +7,7 @@ import logging
import logging.handlers
import six
from importlib import import_module
from future.backports.test.support import import_module
from saml2 import root_logger, BINDING_URI, SAMLError
from saml2 import BINDING_SOAP
@ -22,10 +18,13 @@ from saml2 import BINDING_HTTP_ARTIFACT
from saml2.attribute_converter import ac_factory
from saml2.assertion import Policy
from saml2.mdstore import MetadataStore
from saml2.saml import NAME_FORMAT_URI
from saml2.virtual_org import VirtualOrg
logger = logging.getLogger(__name__)
__author__ = 'rolandh'
COMMON_ARGS = [
"entityid", "xmlsec_binary", "debug", "key_file", "cert_file",
@ -493,6 +492,22 @@ class Config(object):
for key, val in extensions.items():
self.extensions[key] = val
def service_per_endpoint(self, context=None):
"""
List all endpoint this entity publishes and which service and binding
that are behind the endpoint
:param context: Type of entity
:return: Dictionary with endpoint url as key and a tuple of
service and binding as value
"""
endps = self.getattr("endpoints", context)
res = {}
for service, specs in endps.items():
for endp, binding in specs:
res[endp] = (service, binding)
return res
class SPConfig(Config):
def_context = "sp"

View File

@ -64,9 +64,21 @@ class Response(object):
else:
if isinstance(message, six.string_types):
return [message]
elif isinstance(message, six.binary_type):
return [message]
else:
return message
def add_header(self, ava):
"""
Does *NOT* replace a header of the same type, just adds a new
:param ava: (type, value) tuple
"""
self.headers.append(ava)
def reply(self, **kwargs):
return self.response(self.message, **kwargs)
class Created(Response):
_status = "201 Created"

View File

@ -6,10 +6,13 @@ import sys
import json
import requests
import six
from hashlib import sha1
from os.path import isfile
from os.path import join
from future.backports.test.support import import_module
from saml2 import md
from saml2 import saml
from saml2 import samlp
@ -32,7 +35,6 @@ from saml2.validate import valid_instance
from saml2.time_util import valid
from saml2.validate import NotValid
from saml2.sigver import security_context
from importlib import import_module
__author__ = 'rolandh'
@ -225,7 +227,7 @@ class MetaData(object):
'''
raise NotImplementedError
def load(self):
def load(self, *args, **kwargs):
'''
Loads the metadata
'''
@ -634,7 +636,7 @@ class MetaDataFile(InMemoryMetaData):
def get_metadata_content(self):
return open(self.filename, 'rb').read()
def load(self):
def load(self, *args, **kwargs):
_txt = self.get_metadata_content()
return self.parse_and_check_signature(_txt)
@ -655,7 +657,7 @@ class MetaDataLoader(MetaDataFile):
@staticmethod
def get_metadata_loader(func):
if callable(func):
if hasattr(func, '__call__'):
return func
i = func.rfind('.')
@ -673,7 +675,7 @@ class MetaDataLoader(MetaDataFile):
'Module "%s" does not define a "%s" metadata loader' % (
module, attr))
if not callable(metadata_loader):
if not hasattr(metadata_loader, '__call__'):
raise RuntimeError(
'Metadata loader %s.%s must be callable' % (module, attr))
@ -710,7 +712,7 @@ class MetaDataExtern(InMemoryMetaData):
self.security = security
self.http = http
def load(self):
def load(self, *args, **kwargs):
""" Imports metadata by the use of HTTP GET.
If the fingerprint is known the file will be checked for
compliance before it is imported.
@ -734,7 +736,7 @@ class MetaDataMD(InMemoryMetaData):
super(MetaDataMD, self).__init__(attrc, **kwargs)
self.filename = filename
def load(self):
def load(self, *args, **kwargs):
for key, item in json.loads(open(self.filename).read()):
self.entity[key] = item
@ -760,7 +762,7 @@ class MetaDataMDX(InMemoryMetaData):
concatenated with the request URL sent to the MDX server. Defaults to
sha1 transformation.
"""
super(MetaDataMDX, self).__init__(None, None)
super(MetaDataMDX, self).__init__(None, '')
self.url = url
if entity_transform:
@ -769,7 +771,7 @@ class MetaDataMDX(InMemoryMetaData):
self.entity_transform = MetaDataMDX.sha1_entity_transform
def load(self):
def load(self, *args, **kwargs):
# Do nothing
pass
@ -807,7 +809,7 @@ class MetadataStore(MetaData):
:params ca_certs:
:params disable_ssl_certificate_validation:
"""
self.attrc = attrc
MetaData.__init__(self, attrc, check_validity=check_validity)
if disable_ssl_certificate_validation:
self.http = HTTPBase(verify=False, ca_bundle=ca_certs)
@ -821,12 +823,13 @@ class MetadataStore(MetaData):
self.filter = filter
self.to_old = {}
def load(self, typ, *args, **kwargs):
def load(self, *args, **kwargs):
if self.filter:
_args = {"filter": self.filter}
else:
_args = {}
typ = args[0]
if typ == "local":
key = args[0]
# if library read every file in the library

View File

@ -5,9 +5,11 @@ from saml2 import attribute_converter, saml
from attribute_statement_data import *
from pathutils import full_path
from saml2.attribute_converter import AttributeConverterNOOP, AttributeConverter
from saml2.attribute_converter import AttributeConverterNOOP
from saml2.attribute_converter import AttributeConverter
from saml2.attribute_converter import to_local
from saml2.saml import attribute_from_string
from saml2.saml import attribute_statement_from_string
def _eq(l1, l2):
@ -134,22 +136,20 @@ class TestAC():
assert _eq(lan, ['sn', 'givenName', 'title'])
# def test_ava_fro_1(self):
#
# attr = [saml.Attribute(friendly_name="surName",
# name="urn:oid:2.5.4.4",
# name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"),
# saml.Attribute(friendly_name="efternamn",
# name="urn:oid:2.5.4.42",
# name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"),
# saml.Attribute(friendly_name="titel",
# name="urn:oid:2.5.4.12",
# name_format="urn:oasis:names:tc:SAML:2.0:attrname-format:uri")]
#
# result = attribute_converter.ava_fro(self.acs, attr)
#
# print(result)
# assert result == {'givenName': [], 'sn': [], 'title': []}
def test_to_local_name_from_unspecified(self):
_xml = """<?xml version='1.0' encoding='UTF-8'?>
<ns0:AttributeStatement xmlns:ns0="urn:oasis:names:tc:SAML:2.0:assertion">
<ns0:Attribute
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
Name="EmailAddress"
NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified">
<ns0:AttributeValue xsi:type="xs:string">foo@bar.com</ns0:AttributeValue>
</ns0:Attribute></ns0:AttributeStatement>"""
attr = attribute_statement_from_string(_xml)
ava = attribute_converter.to_local(self.acs, attr)
assert _eq(list(ava.keys()), ['EmailAddress'])
def test_to_local_name_from_basic(self):
attr = [
@ -229,7 +229,12 @@ def test_noop_attribute_conversion():
ava = """<?xml version='1.0' encoding='UTF-8'?>
<ns0:Attribute xmlns:ns0="urn:oasis:names:tc:SAML:2.0:assertion" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" FriendlyName="schacHomeOrganization" Name="urn:oid:1.3.6.1.4.1.25178.1.2.9" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"><ns0:AttributeValue xsi:nil="true" xsi:type="xs:string">uu.se</ns0:AttributeValue></ns0:Attribute>"""
<ns0:Attribute xmlns:ns0="urn:oasis:names:tc:SAML:2.0:assertion"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
FriendlyName="schacHomeOrganization" Name="urn:oid:1.3.6.1.4.1.25178.1.2.9"
NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:uri"><ns0
:AttributeValue xsi:nil="true"
xsi:type="xs:string">uu.se</ns0:AttributeValue></ns0:Attribute>"""
def test_schac():
@ -246,5 +251,4 @@ def test_schac():
if __name__ == "__main__":
t = TestAC()
t.setup_class()
t.test_to_attrstat_1()
# test_schac()
t.test_to_local_name_from_unspecified()