Merge branch 'master' of github.com:rohe/pysaml2

This commit is contained in:
Roland Hedberg 2017-04-24 15:46:35 +02:00
commit 91bce29d7f
19 changed files with 274 additions and 195 deletions

4
.gitignore vendored
View File

@ -33,6 +33,8 @@ tmp*
_build/ _build/
.cache .cache
*.swp *.swp
.tox
env
example/idp3/htdocs/login.mako example/idp3/htdocs/login.mako
@ -192,8 +194,6 @@ example/sp-repoze/old_sp.xml
example/sp-repoze/sp_conf_2.Pygmalion example/sp-repoze/sp_conf_2.Pygmalion
.gitignore.swp
example/sp-repoze/sp_conf_2.py example/sp-repoze/sp_conf_2.py
sp.xml sp.xml

View File

@ -3,7 +3,7 @@ PySAML2 - SAML2 in Python
************************* *************************
:Author: Roland Hedberg :Author: Roland Hedberg
:Version: 4.0.4 :Version: 4.4.0
.. image:: https://api.travis-ci.org/rohe/pysaml2.png?branch=master .. image:: https://api.travis-ci.org/rohe/pysaml2.png?branch=master
:target: https://travis-ci.org/rohe/pysaml2 :target: https://travis-ci.org/rohe/pysaml2
@ -26,3 +26,15 @@ necessary pieces for building a SAML2 service provider or an identity provider.
The distribution contains examples of both. The distribution contains examples of both.
Originally written to work in a WSGI environment there are extensions that Originally written to work in a WSGI environment there are extensions that
allow you to use it with other frameworks. allow you to use it with other frameworks.
Testing
=======
PySAML2 uses the `pytest <http://doc.pytest.org/en/latest/>`_ framework for
testing. To run the tests on your system's version of python
1. Create and activate a `virtualenv <https://virtualenv.pypa.io/en/stable/>`_.
2. Inside the virtualenv, install the dependencies needed for testing :code:`pip install -r tests/test_requirements.txt`
3. Run the tests :code:`py.test tests`
To run tests in multiple python environments, you can use
`pyenv <https://github.com/yyuu/pyenv>`_ with `tox <https://tox.readthedocs.io/en/latest/>`_.

View File

@ -14,10 +14,11 @@ install_requires = [
'paste', 'paste',
'zope.interface', 'zope.interface',
'repoze.who', 'repoze.who',
'pycryptodomex', 'cryptography',
'pytz', 'pytz',
'pyOpenSSL', 'pyOpenSSL',
'python-dateutil', 'python-dateutil',
'defusedxml',
'six' 'six'
] ]

View File

@ -17,7 +17,7 @@
provides methods and functions to convert SAML classes to and from strings. provides methods and functions to convert SAML classes to and from strings.
""" """
__version__ = "4.3.0" __version__ = "4.4.0"
import logging import logging
import six import six
@ -36,6 +36,7 @@ except ImportError:
import cElementTree as ElementTree import cElementTree as ElementTree
except ImportError: except ImportError:
from elementtree import ElementTree from elementtree import ElementTree
import defusedxml.ElementTree
root_logger = logging.getLogger(__name__) root_logger = logging.getLogger(__name__)
root_logger.level = logging.NOTSET root_logger.level = logging.NOTSET
@ -87,7 +88,7 @@ def create_class_from_xml_string(target_class, xml_string):
""" """
if not isinstance(xml_string, six.binary_type): if not isinstance(xml_string, six.binary_type):
xml_string = xml_string.encode('utf-8') xml_string = xml_string.encode('utf-8')
tree = ElementTree.fromstring(xml_string) tree = defusedxml.ElementTree.fromstring(xml_string)
return create_class_from_element_tree(target_class, tree) return create_class_from_element_tree(target_class, tree)
@ -269,7 +270,7 @@ class ExtensionElement(object):
def extension_element_from_string(xml_string): def extension_element_from_string(xml_string):
element_tree = ElementTree.fromstring(xml_string) element_tree = defusedxml.ElementTree.fromstring(xml_string)
return _extension_element_from_element_tree(element_tree) return _extension_element_from_element_tree(element_tree)

View File

@ -8,7 +8,11 @@ import six
from OpenSSL import crypto from OpenSSL import crypto
from os.path import join from os.path import join
from os import remove from os import remove
from Cryptodome.Util import asn1
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate
backend = default_backend()
class WrongInput(Exception): class WrongInput(Exception):
pass pass
@ -194,9 +198,8 @@ class OpenSSLWrapper(object):
f.close() f.close()
def read_str_from_file(self, file, type="pem"): def read_str_from_file(self, file, type="pem"):
f = open(file, 'rt') with open(file, 'rb') as f:
str_data = f.read() str_data = f.read()
f.close()
if type == "pem": if type == "pem":
return str_data return str_data
@ -336,31 +339,13 @@ class OpenSSLWrapper(object):
cert_algorithm = cert.get_signature_algorithm() cert_algorithm = cert.get_signature_algorithm()
if six.PY3: if six.PY3:
cert_algorithm = cert_algorithm.decode('ascii') cert_algorithm = cert_algorithm.decode('ascii')
cert_str = cert_str.encode('ascii')
cert_asn1 = crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) cert_crypto = load_pem_x509_certificate(cert_str, backend)
der_seq = asn1.DerSequence()
der_seq.decode(cert_asn1)
cert_certificate = der_seq[0]
#cert_signature_algorithm=der_seq[1]
cert_signature = der_seq[2]
cert_signature_decoded = asn1.DerObject()
cert_signature_decoded.decode(cert_signature)
signature_payload = cert_signature_decoded.payload
sig_pay0 = signature_payload[0]
if ((isinstance(sig_pay0, int) and sig_pay0 != 0) or
(isinstance(sig_pay0, str) and sig_pay0 != '\x00')):
return (False,
"The certificate should not contain any unused bits.")
signature = signature_payload[1:]
try: try:
crypto.verify(ca_cert, signature, cert_certificate, crypto.verify(ca_cert, cert_crypto.signature,
cert_crypto.tbs_certificate_bytes,
cert_algorithm) cert_algorithm)
return True, "Signed certificate is valid and correctly signed by CA certificate." return True, "Signed certificate is valid and correctly signed by CA certificate."
except crypto.Error as e: except crypto.Error as e:

View File

@ -207,7 +207,7 @@ class Base(Entity):
nameid_format=None, nameid_format=None,
service_url_binding=None, message_id=0, service_url_binding=None, message_id=0,
consent=None, extensions=None, sign=None, consent=None, extensions=None, sign=None,
allow_create=False, sign_prepare=False, sign_alg=None, allow_create=None, sign_prepare=False, sign_alg=None,
digest_alg=None, **kwargs): digest_alg=None, **kwargs):
""" Creates an authentication request. """ Creates an authentication request.
@ -288,10 +288,15 @@ class Base(Entity):
args["name_id_policy"] = kwargs["name_id_policy"] args["name_id_policy"] = kwargs["name_id_policy"]
del kwargs["name_id_policy"] del kwargs["name_id_policy"]
except KeyError: except KeyError:
if allow_create: if allow_create is None:
allow_create = "true" allow_create = self.config.getattr("name_id_format_allow_create", "sp")
else: if allow_create is None:
allow_create = "false" allow_create = "false"
else:
if allow_create is True:
allow_create = "true"
else:
allow_create = "false"
if nameid_format == "": if nameid_format == "":
name_id_policy = None name_id_policy = None

View File

@ -1,13 +1,14 @@
#!/usr/bin/env python #!/usr/bin/env python
import copy import copy
import sys import importlib
import os
import re
import logging import logging
import logging.handlers import logging.handlers
import six import os
import re
import sys
from future.backports.test.support import import_module import six
from saml2 import root_logger, BINDING_URI, SAMLError from saml2 import root_logger, BINDING_URI, SAMLError
from saml2 import BINDING_SOAP from saml2 import BINDING_SOAP
@ -72,6 +73,7 @@ SP_ARGS = [
"allow_unsolicited", "allow_unsolicited",
"ecp", "ecp",
"name_id_format", "name_id_format",
"name_id_format_allow_create",
"logout_requests_signed", "logout_requests_signed",
"requested_attribute_name_format" "requested_attribute_name_format"
] ]
@ -186,6 +188,7 @@ class Config(object):
self.contact_person = None self.contact_person = None
self.name_form = None self.name_form = None
self.name_id_format = None self.name_id_format = None
self.name_id_format_allow_create = None
self.virtual_organization = None self.virtual_organization = None
self.logger = None self.logger = None
self.only_use_keys_in_metadata = True self.only_use_keys_in_metadata = True
@ -359,7 +362,7 @@ class Config(object):
else: else:
sys.path.insert(0, head) sys.path.insert(0, head)
return import_module(tail) return importlib.import_module(tail)
def load_file(self, config_file, metadata_construction=False): def load_file(self, config_file, metadata_construction=False):
if config_file.endswith(".py"): if config_file.endswith(".py"):

View File

@ -7,9 +7,6 @@ import six
from binascii import hexlify from binascii import hexlify
from hashlib import sha1 from hashlib import sha1
# from Crypto.PublicKey import RSA
from Cryptodome.PublicKey import RSA
from saml2.metadata import ENDPOINTS from saml2.metadata import ENDPOINTS
from saml2.profile import paos, ecp from saml2.profile import paos, ecp
from saml2.soap import parse_soap_enveloped_saml_artifact_resolve from saml2.soap import parse_soap_enveloped_saml_artifact_resolve

View File

@ -1,17 +1,17 @@
from __future__ import print_function from __future__ import print_function
import hashlib import hashlib
import importlib
import json
import logging import logging
import os import os
import sys import sys
import json
import requests
import six
from hashlib import sha1 from hashlib import sha1
from os.path import isfile from os.path import isfile
from os.path import join from os.path import join
from future.backports.test.support import import_module import requests
import six
from saml2 import md from saml2 import md
from saml2 import saml from saml2 import saml
@ -694,7 +694,7 @@ class MetaDataLoader(MetaDataFile):
i = func.rfind('.') i = func.rfind('.')
module, attr = func[:i], func[i + 1:] module, attr = func[:i], func[i + 1:]
try: try:
mod = import_module(module) mod = importlib.import_module(module)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
'Cannot find metadata provider function %s: "%s"' % (func, e)) 'Cannot find metadata provider function %s: "%s"' % (func, e))
@ -930,7 +930,7 @@ class MetadataStore(MetaData):
raise SAMLError("Misconfiguration in metadata %s" % item) raise SAMLError("Misconfiguration in metadata %s" % item)
mod, clas = key.rsplit('.', 1) mod, clas = key.rsplit('.', 1)
try: try:
mod = import_module(mod) mod = importlib.import_module(mod)
MDloader = getattr(mod, clas) MDloader = getattr(mod, clas)
except (ImportError, AttributeError): except (ImportError, AttributeError):
raise SAMLError("Unknown metadata loader %s" % key) raise SAMLError("Unknown metadata loader %s" % key)

View File

@ -37,6 +37,7 @@ except ImportError:
import cElementTree as ElementTree import cElementTree as ElementTree
except ImportError: except ImportError:
from elementtree import ElementTree from elementtree import ElementTree
import defusedxml.ElementTree
NAMESPACE = "http://schemas.xmlsoap.org/soap/envelope/" NAMESPACE = "http://schemas.xmlsoap.org/soap/envelope/"
FORM_SPEC = """<form method="post" action="%s"> FORM_SPEC = """<form method="post" action="%s">
@ -235,7 +236,7 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None):
:param text: The SOAP object as XML :param text: The SOAP object as XML
:return: header parts and body as saml.samlbase instances :return: header parts and body as saml.samlbase instances
""" """
envelope = ElementTree.fromstring(text) envelope = defusedxml.ElementTree.fromstring(text)
assert envelope.tag == '{%s}Envelope' % NAMESPACE assert envelope.tag == '{%s}Envelope' % NAMESPACE
# print(len(envelope)) # print(len(envelope))

View File

@ -1,33 +1,23 @@
#!/usr/bin/env python #!/usr/bin/env python
import base64
import hashlib
import hmac
import logging import logging
import random import random
import time
import base64
import six
import sys
import hmac
import string import string
import sys
# from python 2.5 import time
import imp
import traceback import traceback
import zlib
if sys.version_info >= (2, 5): import six
import hashlib
else: # before python 2.5
import sha
from saml2 import saml from saml2 import saml
from saml2 import samlp from saml2 import samlp
from saml2 import VERSION from saml2 import VERSION
from saml2.time_util import instant from saml2.time_util import instant
try:
from hashlib import md5
except ImportError:
from md5 import md5
import zlib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -407,67 +397,6 @@ def verify_signature(secret, parts):
return False return False
FTICKS_FORMAT = "F-TICKS/SWAMID/2.0%s#"
def fticks_log(sp, logf, idp_entity_id, user_id, secret, assertion):
"""
'F-TICKS/' federationIdentifier '/' version *('#' attribute '=' value) '#'
Allowed attributes:
TS the login time stamp
RP the relying party entityID
AP the asserting party entityID (typcially the IdP)
PN a sha256-hash of the local principal name and a unique key
AM the authentication method URN
:param sp: Client instance
:param logf: The log function to use
:param idp_entity_id: IdP entity ID
:param user_id: The user identifier
:param secret: A salt to make the hash more secure
:param assertion: A SAML Assertion instance gotten from the IdP
"""
csum = hmac.new(secret, digestmod=hashlib.sha1)
csum.update(user_id)
ac = assertion.AuthnStatement[0].AuthnContext[0]
info = {
"TS": time.time(),
"RP": sp.entity_id,
"AP": idp_entity_id,
"PN": csum.hexdigest(),
"AM": ac.AuthnContextClassRef.text
}
logf.info(FTICKS_FORMAT % "#".join(["%s=%s" % (a, v) for a, v in info]))
def dynamic_importer(name, class_name=None):
"""
Dynamically imports modules / classes
"""
try:
fp, pathname, description = imp.find_module(name)
except ImportError:
print("unable to locate module: " + name)
return None, None
try:
package = imp.load_module(name, fp, pathname, description)
except Exception:
raise
if class_name:
try:
_class = imp.load_module("%s.%s" % (name, class_name), fp,
pathname, description)
except Exception:
raise
return package, _class
else:
return package, None
def exception_trace(exc): def exception_trace(exc):
message = traceback.format_exception(*sys.exc_info()) message = traceback.format_exception(*sys.exc_info())

View File

@ -19,25 +19,13 @@ from binascii import hexlify
from future.backports.urllib.parse import urlencode from future.backports.urllib.parse import urlencode
# from Crypto.PublicKey.RSA import importKey from cryptography.exceptions import InvalidSignature
# from Crypto.Signature import PKCS1_v1_5 from cryptography.hazmat.backends import default_backend
# from Crypto.Util.asn1 import DerSequence from cryptography.hazmat.primitives import hashes
# from Crypto.PublicKey import RSA from cryptography.hazmat.primitives.asymmetric import rsa
# from Crypto.Hash import SHA from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
# from Crypto.Hash import SHA224 from cryptography.hazmat.primitives.serialization import load_pem_private_key
# from Crypto.Hash import SHA256 from cryptography.x509 import load_pem_x509_certificate
# from Crypto.Hash import SHA384
# from Crypto.Hash import SHA512
from Cryptodome.PublicKey.RSA import importKey
from Cryptodome.Signature import PKCS1_v1_5
from Cryptodome.Util.asn1 import DerSequence
from Cryptodome.PublicKey import RSA
from Cryptodome.Hash import SHA
from Cryptodome.Hash import SHA224
from Cryptodome.Hash import SHA256
from Cryptodome.Hash import SHA384
from Cryptodome.Hash import SHA512
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from subprocess import Popen from subprocess import Popen
@ -87,6 +75,8 @@ XMLTAG = "<?xml version='1.0'?>"
PREFIX1 = "<?xml version='1.0' encoding='UTF-8'?>" PREFIX1 = "<?xml version='1.0' encoding='UTF-8'?>"
PREFIX2 = '<?xml version="1.0" encoding="UTF-8"?>' PREFIX2 = '<?xml version="1.0" encoding="UTF-8"?>'
backend = default_backend()
class SigverError(SAMLError): class SigverError(SAMLError):
pass pass
@ -406,18 +396,10 @@ def active_cert(key):
""" """
try: try:
cert_str = pem_format(key) cert_str = pem_format(key)
try: cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
certificate = importKey(cert_str) assert cert.has_expired() == 0
not_before = to_time(str(certificate.get_not_before())) assert not OpenSSLWrapper().certificate_not_valid_yet(cert)
not_after = to_time(str(certificate.get_not_after())) return True
assert not_before < utc_now()
assert not_after > utc_now()
return True
except:
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str)
assert cert.has_expired() == 0
assert not OpenSSLWrapper().certificate_not_valid_yet(cert)
return True
except AssertionError: except AssertionError:
return False return False
except AttributeError: except AttributeError:
@ -555,19 +537,8 @@ def rsa_eq(key1, key2):
def extract_rsa_key_from_x509_cert(pem): def extract_rsa_key_from_x509_cert(pem):
# Convert from PEM to DER cert = load_pem_x509_certificate(pem, backend)
der = ssl.PEM_cert_to_DER_cert(pem.decode('ascii')) return cert.public_key()
# Extract subjectPublicKeyInfo field from X.509 certificate (see RFC3280)
cert = DerSequence()
cert.decode(der)
tbsCertificate = DerSequence()
tbsCertificate.decode(cert[0])
subjectPublicKeyInfo = tbsCertificate[6]
# Initialize RSA key
rsa_key = RSA.importKey(subjectPublicKeyInfo)
return rsa_key
def pem_format(key): def pem_format(key):
@ -576,7 +547,7 @@ def pem_format(key):
def import_rsa_key_from_file(filename): def import_rsa_key_from_file(filename):
return RSA.importKey(read_file(filename, 'r')) return load_pem_private_key(read_file(filename, 'rb'), None, backend)
def parse_xmlsec_output(output): def parse_xmlsec_output(output):
@ -622,25 +593,28 @@ class RSASigner(Signer):
if key is None: if key is None:
key = self.key key = self.key
h = self.digest.new(msg) return key.sign(msg, PKCS1v15(), self.digest)
signer = PKCS1_v1_5.new(key)
return signer.sign(h)
def verify(self, msg, sig, key=None): def verify(self, msg, sig, key=None):
if key is None: if key is None:
key = self.key key = self.key
h = self.digest.new(msg) try:
verifier = PKCS1_v1_5.new(key) if isinstance(key, rsa.RSAPrivateKey):
return verifier.verify(h, sig) key = key.public_key()
key.verify(sig, msg, PKCS1v15(), self.digest)
return True
except InvalidSignature:
return False
SIGNER_ALGS = { SIGNER_ALGS = {
SIG_RSA_SHA1: RSASigner(SHA), SIG_RSA_SHA1: RSASigner(hashes.SHA1()),
SIG_RSA_SHA224: RSASigner(SHA224), SIG_RSA_SHA224: RSASigner(hashes.SHA224()),
SIG_RSA_SHA256: RSASigner(SHA256), SIG_RSA_SHA256: RSASigner(hashes.SHA256()),
SIG_RSA_SHA384: RSASigner(SHA384), SIG_RSA_SHA384: RSASigner(hashes.SHA384()),
SIG_RSA_SHA512: RSASigner(SHA512), SIG_RSA_SHA512: RSASigner(hashes.SHA512()),
} }
REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"] REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"]

View File

@ -19,6 +19,7 @@ except ImportError:
except ImportError: except ImportError:
#noinspection PyUnresolvedReferences #noinspection PyUnresolvedReferences
from elementtree import ElementTree from elementtree import ElementTree
import defusedxml.ElementTree
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -133,7 +134,7 @@ def parse_soap_enveloped_saml_thingy(text, expected_tags):
:param expected_tags: What the tag of the SAML thingy is expected to be. :param expected_tags: What the tag of the SAML thingy is expected to be.
:return: SAML thingy as a string :return: SAML thingy as a string
""" """
envelope = ElementTree.fromstring(text) envelope = defusedxml.ElementTree.fromstring(text)
# Make sure it's a SOAP message # Make sure it's a SOAP message
assert envelope.tag == '{%s}Envelope' % soapenv.NAMESPACE assert envelope.tag == '{%s}Envelope' % soapenv.NAMESPACE
@ -183,7 +184,7 @@ def class_instances_from_soap_enveloped_saml_thingies(text, modules):
:return: The body and headers as class instances :return: The body and headers as class instances
""" """
try: try:
envelope = ElementTree.fromstring(text) envelope = defusedxml.ElementTree.fromstring(text)
except Exception as exc: except Exception as exc:
raise XmlParseError("%s" % exc) raise XmlParseError("%s" % exc)
@ -209,7 +210,7 @@ def open_soap_envelope(text):
:return: dictionary with two keys "body"/"header" :return: dictionary with two keys "body"/"header"
""" """
try: try:
envelope = ElementTree.fromstring(text) envelope = defusedxml.ElementTree.fromstring(text)
except Exception as exc: except Exception as exc:
raise XmlParseError("%s" % exc) raise XmlParseError("%s" % exc)

View File

@ -0,0 +1,64 @@
from pathutils import full_path
from pathutils import xmlsec_path
CONFIG = {
"entityid": "urn:mace:example.com:saml:roland:sp",
"name": "urn:mace:example.com:saml:roland:sp",
"description": "My own SP",
"service": {
"sp": {
"endpoints": {
"assertion_consumer_service": [
"http://lingon.catalogix.se:8087/"],
},
"required_attributes": ["surName", "givenName", "mail"],
"optional_attributes": ["title"],
"idp": ["urn:mace:example.com:saml:roland:idp"],
"name_id_format": "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
"name_id_format_allow_create": "true"
}
},
"debug": 1,
"key_file": full_path("test.key"),
"cert_file": full_path("test.pem"),
"encryption_keypairs": [{"key_file": full_path("test_1.key"), "cert_file": full_path("test_1.crt")},
{"key_file": full_path("test_2.key"), "cert_file": full_path("test_2.crt")}],
"ca_certs": full_path("cacerts.txt"),
"xmlsec_binary": xmlsec_path,
"metadata": [{
"class": "saml2.mdstore.MetaDataFile",
"metadata": [(full_path("idp.xml"), ), (full_path("vo_metadata.xml"), )],
}],
"virtual_organization": {
"urn:mace:example.com:it:tek": {
"nameid_format": "urn:oid:1.3.6.1.4.1.1466.115.121.1.15-NameID",
"common_identifier": "umuselin",
}
},
"subject_data": "subject_data.db",
"accepted_time_diff": 60,
"attribute_map_dir": full_path("attributemaps"),
"valid_for": 6,
"organization": {
"name": ("AB Exempel", "se"),
"display_name": ("AB Exempel", "se"),
"url": "http://www.example.org",
},
"contact_person": [{
"given_name": "Roland",
"sur_name": "Hedberg",
"telephone_number": "+46 70 100 0000",
"email_address": ["tech@eample.com",
"tech@example.org"],
"contact_type": "technical"
},
],
"logger": {
"rotating": {
"filename": full_path("sp.log"),
"maxBytes": 100000,
"backupCount": 5,
},
"loglevel": "info",
}
}

View File

@ -17,6 +17,7 @@ except ImportError:
import cElementTree as ElementTree import cElementTree as ElementTree
except ImportError: except ImportError:
from elementtree import ElementTree from elementtree import ElementTree
from defusedxml.common import EntitiesForbidden
ITEMS = { ITEMS = {
NameID: ["""<?xml version="1.0" encoding="utf-8"?> NameID: ["""<?xml version="1.0" encoding="utf-8"?>
@ -166,6 +167,19 @@ def test_create_class_from_xml_string_wrong_class_spec():
assert kl == None assert kl == None
def test_create_class_from_xml_string_xxe():
xml = """<?xml version="1.0"?>
<!DOCTYPE lolz [
<!ENTITY lol "lol">
<!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
]>
<lolz>&lol1;</lolz>
"""
with raises(EntitiesForbidden) as err:
create_class_from_xml_string(NameID, xml)
def test_ee_1(): def test_ee_1():
ee = saml2.extension_element_from_string( ee = saml2.extension_element_from_string(
"""<?xml version='1.0' encoding='UTF-8'?><foo>bar</foo>""") """<?xml version='1.0' encoding='UTF-8'?><foo>bar</foo>""")
@ -454,6 +468,19 @@ def test_ee_7():
assert nid.text.strip() == "http://federationX.org" assert nid.text.strip() == "http://federationX.org"
def test_ee_xxe():
xml = """<?xml version="1.0"?>
<!DOCTYPE lolz [
<!ENTITY lol "lol">
<!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
]>
<lolz>&lol1;</lolz>
"""
with raises(EntitiesForbidden):
saml2.extension_element_from_string(xml)
def test_extension_element_loadd(): def test_extension_element_loadd():
ava = {'attributes': {}, ava = {'attributes': {},
'tag': 'ExternalEntityAttributeAuthority', 'tag': 'ExternalEntityAttributeAuthority',

View File

@ -12,9 +12,13 @@ except ImportError:
import cElementTree as ElementTree import cElementTree as ElementTree
except ImportError: except ImportError:
from elementtree import ElementTree from elementtree import ElementTree
from defusedxml.common import EntitiesForbidden
from pytest import raises
import saml2.samlp as samlp import saml2.samlp as samlp
from saml2.samlp import NAMESPACE as SAMLP_NAMESPACE from saml2.samlp import NAMESPACE as SAMLP_NAMESPACE
from saml2 import soap
NAMESPACE = "http://schemas.xmlsoap.org/soap/envelope/" NAMESPACE = "http://schemas.xmlsoap.org/soap/envelope/"
@ -66,3 +70,42 @@ def test_make_soap_envelope():
assert len(body) == 1 assert len(body) == 1
saml_part = body[0] saml_part = body[0]
assert saml_part.tag == '{%s}AuthnRequest' % SAMLP_NAMESPACE assert saml_part.tag == '{%s}AuthnRequest' % SAMLP_NAMESPACE
def test_parse_soap_enveloped_saml_thingy_xxe():
xml = """<?xml version="1.0"?>
<!DOCTYPE lolz [
<!ENTITY lol "lol">
<!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
]>
<lolz>&lol1;</lolz>
"""
with raises(EntitiesForbidden):
soap.parse_soap_enveloped_saml_thingy(xml, None)
def test_class_instances_from_soap_enveloped_saml_thingies_xxe():
xml = """<?xml version="1.0"?>
<!DOCTYPE lolz [
<!ENTITY lol "lol">
<!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
]>
<lolz>&lol1;</lolz>
"""
with raises(soap.XmlParseError):
soap.class_instances_from_soap_enveloped_saml_thingies(xml, None)
def test_open_soap_envelope_xxe():
xml = """<?xml version="1.0"?>
<!DOCTYPE lolz [
<!ENTITY lol "lol">
<!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
]>
<lolz>&lol1;</lolz>
"""
with raises(soap.XmlParseError):
soap.open_soap_envelope(xml)

View File

@ -7,6 +7,7 @@ import six
from future.backports.urllib.parse import parse_qs from future.backports.urllib.parse import parse_qs
from future.backports.urllib.parse import urlencode from future.backports.urllib.parse import urlencode
from future.backports.urllib.parse import urlparse from future.backports.urllib.parse import urlparse
from pytest import raises
from saml2.argtree import add_path from saml2.argtree import add_path
from saml2.cert import OpenSSLWrapper from saml2.cert import OpenSSLWrapper
@ -25,6 +26,7 @@ from saml2.assertion import Assertion
from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.authn_context import INTERNETPROTOCOLPASSWORD
from saml2.client import Saml2Client from saml2.client import Saml2Client
from saml2.config import SPConfig from saml2.config import SPConfig
from saml2.pack import parse_soap_enveloped_saml
from saml2.response import LogoutResponse from saml2.response import LogoutResponse
from saml2.saml import NAMEID_FORMAT_PERSISTENT, EncryptedAssertion, Advice from saml2.saml import NAMEID_FORMAT_PERSISTENT, EncryptedAssertion, Advice
from saml2.saml import NAMEID_FORMAT_TRANSIENT from saml2.saml import NAMEID_FORMAT_TRANSIENT
@ -38,6 +40,8 @@ from saml2.s_utils import do_attribute_statement
from saml2.s_utils import factory from saml2.s_utils import factory
from saml2.time_util import in_a_while, a_while_ago from saml2.time_util import in_a_while, a_while_ago
from defusedxml.common import EntitiesForbidden
from fakeIDP import FakeIDP from fakeIDP import FakeIDP
from fakeIDP import unpack_form from fakeIDP import unpack_form
from pathutils import full_path from pathutils import full_path
@ -276,6 +280,26 @@ class TestClient:
assert nid_policy.allow_create == "false" assert nid_policy.allow_create == "false"
assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT assert nid_policy.format == saml.NAMEID_FORMAT_TRANSIENT
def test_create_auth_request_nameid_policy_allow_create(self):
conf = config.SPConfig()
conf.load_file("sp_conf_nameidpolicy")
client = Saml2Client(conf)
ar_str = "%s" % client.create_authn_request(
"http://www.example.com/sso", message_id="id1")[1]
ar = samlp.authn_request_from_string(ar_str)
print(ar)
assert ar.assertion_consumer_service_url == ("http://lingon.catalogix"
".se:8087/")
assert ar.destination == "http://www.example.com/sso"
assert ar.protocol_binding == BINDING_HTTP_POST
assert ar.version == "2.0"
assert ar.provider_name == "urn:mace:example.com:saml:roland:sp"
assert ar.issuer.text == "urn:mace:example.com:saml:roland:sp"
nid_policy = ar.name_id_policy
assert nid_policy.allow_create == "true"
assert nid_policy.format == saml.NAMEID_FORMAT_PERSISTENT
def test_create_auth_request_vo(self): def test_create_auth_request_vo(self):
assert list(self.client.config.vorg.keys()) == [ assert list(self.client.config.vorg.keys()) == [
"urn:mace:example.com:it:tek"] "urn:mace:example.com:it:tek"]
@ -1552,6 +1576,17 @@ class TestClientWithDummy():
'http://www.example.com/login' 'http://www.example.com/login'
assert ac.authn_context_class_ref.text == INTERNETPROTOCOLPASSWORD assert ac.authn_context_class_ref.text == INTERNETPROTOCOLPASSWORD
def test_parse_soap_enveloped_saml_xxe():
xml = """<?xml version="1.0"?>
<!DOCTYPE lolz [
<!ENTITY lol "lol">
<!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
]>
<lolz>&lol1;</lolz>
"""
with raises(EntitiesForbidden):
parse_soap_enveloped_saml(xml, None)
# if __name__ == "__main__": # if __name__ == "__main__":
# tc = TestClient() # tc = TestClient()

View File

@ -1,3 +1,5 @@
mock==2.0.0
pymongo==3.0.1 pymongo==3.0.1
pytest==3.0.3
responses==0.5.0 responses==0.5.0
mock pyasn1==0.2.3

View File

@ -2,6 +2,5 @@
envlist = py27,py34 envlist = py27,py34
[testenv] [testenv]
deps = pytest deps = -rtests/test_requirements.txt
-rtests/test_requirements.txt
commands = py.test tests/ commands = py.test tests/