add closing method to make sure to close ident db connection

This commit is contained in:
Erick Tryzelaar
2014-10-24 14:57:34 -07:00
parent bf9747cd97
commit ff5cb7d8ee
17 changed files with 533 additions and 513 deletions

View File

@@ -334,4 +334,5 @@ class IdentDB(object):
return name_id return name_id
def close(self): def close(self):
if hasattr(self.db, 'close'):
self.db.close() self.db.close()

View File

@@ -145,6 +145,7 @@ class Server(Entity):
raise Exception("Couldn't open identity database: %s" % raise Exception("Couldn't open identity database: %s" %
(dbspec,)) (dbspec,))
try:
_domain = self.config.getattr("domain", "idp") _domain = self.config.getattr("domain", "idp")
if _domain: if _domain:
self.ident.domain = _domain self.ident.domain = _domain
@@ -167,6 +168,9 @@ class Server(Entity):
collection="eptid") collection="eptid")
else: else:
self.eptid = Eptid(secret) self.eptid = Eptid(secret)
except Exception:
self.ident.close()
raise
def wants(self, sp_entity_id, index=None): def wants(self, sp_entity_id, index=None):
""" Returns what attributes the SP requires and which are optional """ Returns what attributes the SP requires and which are optional
@@ -681,3 +685,6 @@ class Server(Entity):
soap_envelope = soapenv.Envelope(header=header, body=body) soap_envelope = soapenv.Envelope(header=header, body=body)
return "%s" % soap_envelope return "%s" % soap_envelope
def close(self):
self.ident.close()

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from saml2 import saml, sigver from saml2 import saml, sigver
from saml2 import md from saml2 import md
from saml2 import config from saml2 import config
@@ -150,8 +151,7 @@ def test_filter_ava5():
def test_idp_policy_filter(): def test_idp_policy_filter():
idp = Server("idp_conf_ec") with closing(Server("idp_conf_ec")) as idp:
ava = {"givenName": ["Derek"], "sn": ["Jeter"], ava = {"givenName": ["Derek"], "sn": ["Jeter"],
"mail": ["derek@nyy.mlb.com"], "c": ["USA"], "mail": ["derek@nyy.mlb.com"], "c": ["USA"],
"eduPersonTargetedID": "foo!bar!xyz", "eduPersonTargetedID": "foo!bar!xyz",

View File

@@ -1,6 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from contextlib import closing
from saml2 import config from saml2 import config
from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.authn_context import INTERNETPROTOCOLPASSWORD
@@ -34,7 +36,7 @@ AUTHN = {
class TestResponse: class TestResponse:
def setup_class(self): def setup_class(self):
server = Server("idp_conf") with closing(Server("idp_conf")) as server:
name_id = server.ident.transient_nameid( name_id = server.ident.transient_nameid(
"urn:mace:example.com:saml:roland:sp", "id12") "urn:mace:example.com:saml:roland:sp", "id12")

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.authn_context import INTERNETPROTOCOLPASSWORD
from saml2.server import Server from saml2.server import Server
from saml2.sigver import pre_encryption_part, ASSERT_XPATH, EncryptError from saml2.sigver import pre_encryption_part, ASSERT_XPATH, EncryptError
@@ -30,7 +31,7 @@ def test_pre_enc():
def test_reshuffle_response(): def test_reshuffle_response():
server = Server("idp_conf") with closing(Server("idp_conf")) as server:
name_id = server.ident.transient_nameid( name_id = server.ident.transient_nameid(
"urn:mace:example.com:saml:roland:sp", "id12") "urn:mace:example.com:saml:roland:sp", "id12")
@@ -45,7 +46,7 @@ def test_reshuffle_response():
def test_enc1(): def test_enc1():
server = Server("idp_conf") with closing(Server("idp_conf")) as server:
name_id = server.ident.transient_nameid( name_id = server.ident.transient_nameid(
"urn:mace:example.com:saml:roland:sp", "id12") "urn:mace:example.com:saml:roland:sp", "id12")
@@ -82,7 +83,7 @@ def test_enc1():
def test_enc2(): def test_enc2():
crypto = CryptoBackendXmlSec1(xmlsec_path) crypto = CryptoBackendXmlSec1(xmlsec_path)
server = Server("idp_conf") with closing(Server("idp_conf")) as server:
name_id = server.ident.transient_nameid( name_id = server.ident.transient_nameid(
"urn:mace:example.com:saml:roland:sp", "id12") "urn:mace:example.com:saml:roland:sp", "id12")

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from contextlib import closing
from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.authn_context import INTERNETPROTOCOLPASSWORD
from saml2.server import Server from saml2.server import Server
@@ -28,7 +29,7 @@ AUTHN = {
class TestAuthnResponse: class TestAuthnResponse:
def setup_class(self): def setup_class(self):
server = Server(dotname("idp_conf")) with closing(Server(dotname("idp_conf"))) as server:
name_id = server.ident.transient_nameid( name_id = server.ident.transient_nameid(
"urn:mace:example.com:saml:roland:sp","id12") "urn:mace:example.com:saml:roland:sp","id12")

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import base64 import base64
from contextlib import closing
from urlparse import parse_qs from urlparse import parse_qs
from saml2.sigver import pre_encryption_part from saml2.sigver import pre_encryption_part
from saml2.assertion import Policy from saml2.assertion import Policy
@@ -49,7 +50,7 @@ class TestServer1():
self.client = client.Saml2Client(conf) self.client = client.Saml2Client(conf)
def teardown_class(self): def teardown_class(self):
self.server.ident.close() self.server.close()
def test_issuer(self): def test_issuer(self):
issuer = self.server._issuer() issuer = self.server._issuer()
@@ -419,7 +420,8 @@ class TestServer1():
saml_soap = make_soap_enveloped_saml_thingy(logout_request) saml_soap = make_soap_enveloped_saml_thingy(logout_request)
self.server.ident.close() self.server.ident.close()
idp = Server("idp_soap_conf")
with closing(Server("idp_soap_conf")) as idp:
request = idp.parse_logout_request(saml_soap) request = idp.parse_logout_request(saml_soap)
idp.ident.close() idp.ident.close()
assert request assert request
@@ -436,7 +438,7 @@ class TestServer2():
self.server = Server("restrictive_idp_conf") self.server = Server("restrictive_idp_conf")
def teardown_class(self): def teardown_class(self):
self.server.ident.close() self.server.close()
def test_do_attribute_reponse(self): def test_do_attribute_reponse(self):
aa_policy = self.server.config.getattr("policy", "idp") aa_policy = self.server.config.getattr("policy", "idp")
@@ -487,7 +489,7 @@ def _logout_request(conf_file):
class TestServerLogout(): class TestServerLogout():
def test_1(self): def test_1(self):
server = Server("idp_slo_redirect_conf") with closing(Server("idp_slo_redirect_conf")) as server:
req_id, request = _logout_request("sp_slo_redirect_conf") req_id, request = _logout_request("sp_slo_redirect_conf")
print request print request
bindings = [BINDING_HTTP_REDIRECT] bindings = [BINDING_HTTP_REDIRECT]

View File

@@ -120,6 +120,9 @@ class TestClient:
conf.load_file("server_conf") conf.load_file("server_conf")
self.client = Saml2Client(conf) self.client = Saml2Client(conf)
def teardown_class(self):
self.server.close()
def test_create_attribute_query1(self): def test_create_attribute_query1(self):
req_id, req = self.client.create_attribute_query( req_id, req = self.client.create_attribute_query(
"https://idp.example.com/idp/", "https://idp.example.com/idp/",

View File

@@ -48,6 +48,9 @@ class TestSP():
self.sp = make_plugin("rem", saml_conf="server_conf") self.sp = make_plugin("rem", saml_conf="server_conf")
self.server = Server(config_file="idp_conf") self.server = Server(config_file="idp_conf")
def teardown_class(self):
self.server.close()
def test_setup(self): def test_setup(self):
assert self.sp assert self.sp

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.authn_context import INTERNETPROTOCOLPASSWORD
from saml2.httpbase import set_list2dict from saml2.httpbase import set_list2dict
from saml2.profile.ecp import RelayState from saml2.profile.ecp import RelayState
@@ -41,8 +42,8 @@ def test_complete_flow():
metadata_file=full_path("idp_all.xml")) metadata_file=full_path("idp_all.xml"))
sp = Saml2Client(config_file=dotname("servera_conf")) sp = Saml2Client(config_file=dotname("servera_conf"))
idp = Server(config_file=dotname("idp_all_conf"))
with closing(Server(config_file=dotname("idp_all_conf"))) as idp:
IDP_ENTITY_ID = idp.config.entityid IDP_ENTITY_ID = idp.config.entityid
#SP_ENTITY_ID = sp.config.entityid #SP_ENTITY_ID = sp.config.entityid

View File

@@ -1,4 +1,5 @@
import base64 import base64
from contextlib import closing
from hashlib import sha1 from hashlib import sha1
from urlparse import urlparse from urlparse import urlparse
from urlparse import parse_qs from urlparse import parse_qs
@@ -76,8 +77,7 @@ def test_create_artifact_resolve():
s = sha1(SP) s = sha1(SP)
assert artifact[4:24] == s.digest() assert artifact[4:24] == s.digest()
idp = Server(config_file="idp_all_conf") with closing(Server(config_file="idp_all_conf")) as idp:
typecode = artifact[:2] typecode = artifact[:2]
assert typecode == ARTIFACT_TYPECODE assert typecode == ARTIFACT_TYPECODE
@@ -101,8 +101,8 @@ def test_create_artifact_resolve():
def test_artifact_flow(): def test_artifact_flow():
#SP = 'urn:mace:example.com:saml:roland:sp' #SP = 'urn:mace:example.com:saml:roland:sp'
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf")
with closing(Server(config_file="idp_all_conf")) as idp:
# original request # original request
binding, destination = sp.pick_binding("single_sign_on_service", binding, destination = sp.pick_binding("single_sign_on_service",

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from urlparse import urlparse, parse_qs from urlparse import urlparse, parse_qs
from saml2 import BINDING_SOAP, BINDING_HTTP_POST from saml2 import BINDING_SOAP, BINDING_HTTP_POST
@@ -43,8 +44,7 @@ def get_msg(hinfo, binding):
def test_basic(): def test_basic():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf") with closing(Server(config_file="idp_all_conf")) as idp:
srvs = sp.metadata.authn_query_service(idp.config.entityid) srvs = sp.metadata.authn_query_service(idp.config.entityid)
destination = srvs[0]["location"] destination = srvs[0]["location"]
@@ -62,8 +62,8 @@ def test_basic():
def test_flow(): def test_flow():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf")
with closing(Server(config_file="idp_all_conf")) as idp:
relay_state = "FOO" relay_state = "FOO"
# -- dummy request --- # -- dummy request ---
orig_req = AuthnRequest( orig_req = AuthnRequest(

View File

@@ -1,5 +1,6 @@
__author__ = 'rolandh' __author__ = 'rolandh'
from contextlib import closing
from saml2.client import Saml2Client from saml2.client import Saml2Client
from saml2.saml import NameID, NAMEID_FORMAT_PERSISTENT from saml2.saml import NameID, NAMEID_FORMAT_PERSISTENT
from saml2.saml import NAMEID_FORMAT_TRANSIENT from saml2.saml import NAMEID_FORMAT_TRANSIENT
@@ -10,8 +11,8 @@ from saml2.samlp import NameIDMappingRequest
def test_base_request(): def test_base_request():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf")
with closing(Server(config_file="idp_all_conf")) as idp:
binding, destination = sp.pick_binding("name_id_mapping_service", binding, destination = sp.pick_binding("name_id_mapping_service",
entity_id=idp.config.entityid) entity_id=idp.config.entityid)
@@ -30,8 +31,8 @@ def test_base_request():
def test_request_response(): def test_request_response():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf")
with closing(Server(config_file="idp_all_conf")) as idp:
binding, destination = sp.pick_binding("name_id_mapping_service", binding, destination = sp.pick_binding("name_id_mapping_service",
entity_id=idp.config.entityid) entity_id=idp.config.entityid)

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from saml2 import BINDING_SOAP from saml2 import BINDING_SOAP
from saml2.samlp import NewID from saml2.samlp import NewID
from saml2.saml import NameID, NAMEID_FORMAT_TRANSIENT from saml2.saml import NameID, NAMEID_FORMAT_TRANSIENT
@@ -9,8 +10,7 @@ __author__ = 'rolandh'
def test_basic(): def test_basic():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf") with closing(Server(config_file="idp_all_conf")) as idp:
# -------- @SP ------------ # -------- @SP ------------
binding, destination = sp.pick_binding("manage_name_id_service", binding, destination = sp.pick_binding("manage_name_id_service",
entity_id=idp.config.entityid) entity_id=idp.config.entityid)
@@ -35,8 +35,7 @@ def test_basic():
def test_flow(): def test_flow():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf") with closing(Server(config_file="idp_all_conf")) as idp:
binding, destination = sp.pick_binding("manage_name_id_service", binding, destination = sp.pick_binding("manage_name_id_service",
entity_id=idp.config.entityid) entity_id=idp.config.entityid)

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from urlparse import parse_qs from urlparse import parse_qs
from urlparse import urlparse from urlparse import urlparse
from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.authn_context import INTERNETPROTOCOLPASSWORD
@@ -46,8 +47,7 @@ def get_msg(hinfo, binding, response=False):
def test_basic_flow(): def test_basic_flow():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp = Server(config_file="idp_all_conf") with closing(Server(config_file="idp_all_conf")) as idp:
# -------- @IDP ------------- # -------- @IDP -------------
relay_state = "FOO" relay_state = "FOO"

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from saml2.pack import http_redirect_message from saml2.pack import http_redirect_message
from saml2.sigver import verify_redirect_signature from saml2.sigver import verify_redirect_signature
from saml2.sigver import import_rsa_key_from_file from saml2.sigver import import_rsa_key_from_file
@@ -12,14 +13,12 @@ from pathutils import dotname
__author__ = 'rolandh' __author__ = 'rolandh'
idp = Server(config_file=dotname("idp_all_conf"))
conf = SPConfig()
conf.load_file(dotname("servera_conf"))
sp = Saml2Client(conf)
def test(): def test():
with closing(Server(config_file=dotname("idp_all_conf"))) as idp:
conf = SPConfig()
conf.load_file(dotname("servera_conf"))
sp = Saml2Client(conf)
srvs = sp.metadata.single_sign_on_service(idp.config.entityid, srvs = sp.metadata.single_sign_on_service(idp.config.entityid,
BINDING_HTTP_REDIRECT) BINDING_HTTP_REDIRECT)

View File

@@ -1,3 +1,4 @@
from contextlib import closing
from saml2 import BINDING_HTTP_POST from saml2 import BINDING_HTTP_POST
from saml2.authn_context import INTERNETPROTOCOLPASSWORD from saml2.authn_context import INTERNETPROTOCOLPASSWORD
from saml2.client import Saml2Client from saml2.client import Saml2Client
@@ -19,9 +20,8 @@ def _eq(l1, l2):
def test_flow(): def test_flow():
sp = Saml2Client(config_file="servera_conf") sp = Saml2Client(config_file="servera_conf")
idp1 = Server(config_file="idp_conf_mdb") with closing(Server(config_file="idp_conf_mdb")) as idp1:
idp2 = Server(config_file="idp_conf_mdb") with closing(Server(config_file="idp_conf_mdb")) as idp2:
# clean out database # clean out database
idp1.ident.mdb.db.drop() idp1.ident.mdb.db.drop()