This branch allows LdapDriver to reconnect to LDAP server if connection is lost.
This commit is contained in:
@@ -100,6 +100,11 @@ class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable=C0103
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SERVER_DOWN(Exception): # pylint: disable=C0103
|
||||||
|
"""Duplicate exception class from real LDAP module."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def initialize(_uri):
|
def initialize(_uri):
|
||||||
"""Opens a fake connection with an LDAP server."""
|
"""Opens a fake connection with an LDAP server."""
|
||||||
return FakeLDAP()
|
return FakeLDAP()
|
||||||
@@ -202,25 +207,38 @@ def _to_json(unencoded):
|
|||||||
return json.dumps(list(unencoded))
|
return json.dumps(list(unencoded))
|
||||||
|
|
||||||
|
|
||||||
|
server_fail = False
|
||||||
|
|
||||||
|
|
||||||
class FakeLDAP(object):
|
class FakeLDAP(object):
|
||||||
"""Fake LDAP connection."""
|
"""Fake LDAP connection."""
|
||||||
|
|
||||||
def simple_bind_s(self, dn, password):
|
def simple_bind_s(self, dn, password):
|
||||||
"""This method is ignored, but provided for compatibility."""
|
"""This method is ignored, but provided for compatibility."""
|
||||||
|
if server_fail:
|
||||||
|
raise SERVER_DOWN
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def unbind_s(self):
|
def unbind_s(self):
|
||||||
"""This method is ignored, but provided for compatibility."""
|
"""This method is ignored, but provided for compatibility."""
|
||||||
|
if server_fail:
|
||||||
|
raise SERVER_DOWN
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add_s(self, dn, attr):
|
def add_s(self, dn, attr):
|
||||||
"""Add an object with the specified attributes at dn."""
|
"""Add an object with the specified attributes at dn."""
|
||||||
|
if server_fail:
|
||||||
|
raise SERVER_DOWN
|
||||||
|
|
||||||
key = "%s%s" % (self.__prefix, dn)
|
key = "%s%s" % (self.__prefix, dn)
|
||||||
value_dict = dict([(k, _to_json(v)) for k, v in attr])
|
value_dict = dict([(k, _to_json(v)) for k, v in attr])
|
||||||
Store.instance().hmset(key, value_dict)
|
Store.instance().hmset(key, value_dict)
|
||||||
|
|
||||||
def delete_s(self, dn):
|
def delete_s(self, dn):
|
||||||
"""Remove the ldap object at specified dn."""
|
"""Remove the ldap object at specified dn."""
|
||||||
|
if server_fail:
|
||||||
|
raise SERVER_DOWN
|
||||||
|
|
||||||
Store.instance().delete("%s%s" % (self.__prefix, dn))
|
Store.instance().delete("%s%s" % (self.__prefix, dn))
|
||||||
|
|
||||||
def modify_s(self, dn, attrs):
|
def modify_s(self, dn, attrs):
|
||||||
@@ -232,6 +250,9 @@ class FakeLDAP(object):
|
|||||||
([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value)
|
([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if server_fail:
|
||||||
|
raise SERVER_DOWN
|
||||||
|
|
||||||
store = Store.instance()
|
store = Store.instance()
|
||||||
key = "%s%s" % (self.__prefix, dn)
|
key = "%s%s" % (self.__prefix, dn)
|
||||||
|
|
||||||
@@ -255,6 +276,9 @@ class FakeLDAP(object):
|
|||||||
fields -- fields to return. Returns all fields if not specified
|
fields -- fields to return. Returns all fields if not specified
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if server_fail:
|
||||||
|
raise SERVER_DOWN
|
||||||
|
|
||||||
if scope != SCOPE_BASE and scope != SCOPE_SUBTREE:
|
if scope != SCOPE_BASE and scope != SCOPE_SUBTREE:
|
||||||
raise NotImplementedError(str(scope))
|
raise NotImplementedError(str(scope))
|
||||||
store = Store.instance()
|
store = Store.instance()
|
||||||
|
|||||||
@@ -101,6 +101,41 @@ def sanitize(fn):
|
|||||||
return _wrapped
|
return _wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class LDAPWrapper(object):
|
||||||
|
def __init__(self, ldap, url, user, password):
|
||||||
|
self.ldap = ldap
|
||||||
|
self.url = url
|
||||||
|
self.user = user
|
||||||
|
self.password = password
|
||||||
|
self.conn = None
|
||||||
|
|
||||||
|
def __wrap_reconnect(f):
|
||||||
|
def inner(self, *args, **kwargs):
|
||||||
|
if self.conn is None:
|
||||||
|
self.connect()
|
||||||
|
return f(self.conn)(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return f(self.conn)(*args, **kwargs)
|
||||||
|
except self.ldap.SERVER_DOWN:
|
||||||
|
self.connect()
|
||||||
|
return f(self.conn)(*args, **kwargs)
|
||||||
|
return inner
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
try:
|
||||||
|
self.conn = self.ldap.initialize(self.url)
|
||||||
|
self.conn.simple_bind_s(self.user, self.password)
|
||||||
|
except self.ldap.SERVER_DOWN:
|
||||||
|
self.conn = None
|
||||||
|
raise
|
||||||
|
|
||||||
|
search_s = __wrap_reconnect(lambda conn: conn.search_s)
|
||||||
|
add_s = __wrap_reconnect(lambda conn: conn.add_s)
|
||||||
|
delete_s = __wrap_reconnect(lambda conn: conn.delete_s)
|
||||||
|
modify_s = __wrap_reconnect(lambda conn: conn.modify_s)
|
||||||
|
|
||||||
|
|
||||||
class LdapDriver(object):
|
class LdapDriver(object):
|
||||||
"""Ldap Auth driver
|
"""Ldap Auth driver
|
||||||
|
|
||||||
@@ -124,8 +159,8 @@ class LdapDriver(object):
|
|||||||
LdapDriver.project_objectclass = 'novaProject'
|
LdapDriver.project_objectclass = 'novaProject'
|
||||||
self.__cache = None
|
self.__cache = None
|
||||||
if LdapDriver.conn is None:
|
if LdapDriver.conn is None:
|
||||||
LdapDriver.conn = self.ldap.initialize(FLAGS.ldap_url)
|
LdapDriver.conn = LDAPWrapper(self.ldap, FLAGS.ldap_url,
|
||||||
LdapDriver.conn.simple_bind_s(FLAGS.ldap_user_dn,
|
FLAGS.ldap_user_dn,
|
||||||
FLAGS.ldap_password)
|
FLAGS.ldap_password)
|
||||||
if LdapDriver.mc is None:
|
if LdapDriver.mc is None:
|
||||||
LdapDriver.mc = memcache.Client(FLAGS.memcached_servers, debug=0)
|
LdapDriver.mc = memcache.Client(FLAGS.memcached_servers, debug=0)
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from nova import log as logging
|
|||||||
from nova import test
|
from nova import test
|
||||||
from nova.auth import manager
|
from nova.auth import manager
|
||||||
from nova.api.ec2 import cloud
|
from nova.api.ec2 import cloud
|
||||||
|
from nova.auth import fakeldap
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
LOG = logging.getLogger('nova.tests.auth_unittest')
|
LOG = logging.getLogger('nova.tests.auth_unittest')
|
||||||
@@ -369,6 +370,15 @@ class _AuthManagerBaseTestCase(test.TestCase):
|
|||||||
class AuthManagerLdapTestCase(_AuthManagerBaseTestCase):
|
class AuthManagerLdapTestCase(_AuthManagerBaseTestCase):
|
||||||
auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
|
auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
|
||||||
|
|
||||||
|
def test_reconnect_on_server_failure(self):
|
||||||
|
self.manager.get_users()
|
||||||
|
fakeldap.server_fail = True
|
||||||
|
try:
|
||||||
|
self.assertRaises(fakeldap.SERVER_DOWN, self.manager.get_users)
|
||||||
|
finally:
|
||||||
|
fakeldap.server_fail = False
|
||||||
|
self.manager.get_users()
|
||||||
|
|
||||||
|
|
||||||
class AuthManagerDbTestCase(_AuthManagerBaseTestCase):
|
class AuthManagerDbTestCase(_AuthManagerBaseTestCase):
|
||||||
auth_driver = 'nova.auth.dbdriver.DbDriver'
|
auth_driver = 'nova.auth.dbdriver.DbDriver'
|
||||||
|
|||||||
Reference in New Issue
Block a user