This branch allows LdapDriver to reconnect to LDAP server if connection is lost.

This commit is contained in:
Yuriy Taraday
2011-06-28 15:32:02 +00:00
committed by Tarmac
3 changed files with 71 additions and 2 deletions

View File

@@ -100,6 +100,11 @@ class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable=C0103
pass pass
class SERVER_DOWN(Exception): # pylint: disable=C0103
"""Duplicate exception class from real LDAP module."""
pass
def initialize(_uri): def initialize(_uri):
"""Opens a fake connection with an LDAP server.""" """Opens a fake connection with an LDAP server."""
return FakeLDAP() return FakeLDAP()
@@ -202,25 +207,38 @@ def _to_json(unencoded):
return json.dumps(list(unencoded)) return json.dumps(list(unencoded))
server_fail = False
class FakeLDAP(object): class FakeLDAP(object):
"""Fake LDAP connection.""" """Fake LDAP connection."""
def simple_bind_s(self, dn, password): def simple_bind_s(self, dn, password):
"""This method is ignored, but provided for compatibility.""" """This method is ignored, but provided for compatibility."""
if server_fail:
raise SERVER_DOWN
pass pass
def unbind_s(self): def unbind_s(self):
"""This method is ignored, but provided for compatibility.""" """This method is ignored, but provided for compatibility."""
if server_fail:
raise SERVER_DOWN
pass pass
def add_s(self, dn, attr): def add_s(self, dn, attr):
"""Add an object with the specified attributes at dn.""" """Add an object with the specified attributes at dn."""
if server_fail:
raise SERVER_DOWN
key = "%s%s" % (self.__prefix, dn) key = "%s%s" % (self.__prefix, dn)
value_dict = dict([(k, _to_json(v)) for k, v in attr]) value_dict = dict([(k, _to_json(v)) for k, v in attr])
Store.instance().hmset(key, value_dict) Store.instance().hmset(key, value_dict)
def delete_s(self, dn): def delete_s(self, dn):
"""Remove the ldap object at specified dn.""" """Remove the ldap object at specified dn."""
if server_fail:
raise SERVER_DOWN
Store.instance().delete("%s%s" % (self.__prefix, dn)) Store.instance().delete("%s%s" % (self.__prefix, dn))
def modify_s(self, dn, attrs): def modify_s(self, dn, attrs):
@@ -232,6 +250,9 @@ class FakeLDAP(object):
([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value) ([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value)
""" """
if server_fail:
raise SERVER_DOWN
store = Store.instance() store = Store.instance()
key = "%s%s" % (self.__prefix, dn) key = "%s%s" % (self.__prefix, dn)
@@ -255,6 +276,9 @@ class FakeLDAP(object):
fields -- fields to return. Returns all fields if not specified fields -- fields to return. Returns all fields if not specified
""" """
if server_fail:
raise SERVER_DOWN
if scope != SCOPE_BASE and scope != SCOPE_SUBTREE: if scope != SCOPE_BASE and scope != SCOPE_SUBTREE:
raise NotImplementedError(str(scope)) raise NotImplementedError(str(scope))
store = Store.instance() store = Store.instance()

View File

@@ -101,6 +101,41 @@ def sanitize(fn):
return _wrapped return _wrapped
class LDAPWrapper(object):
def __init__(self, ldap, url, user, password):
self.ldap = ldap
self.url = url
self.user = user
self.password = password
self.conn = None
def __wrap_reconnect(f):
def inner(self, *args, **kwargs):
if self.conn is None:
self.connect()
return f(self.conn)(*args, **kwargs)
else:
try:
return f(self.conn)(*args, **kwargs)
except self.ldap.SERVER_DOWN:
self.connect()
return f(self.conn)(*args, **kwargs)
return inner
def connect(self):
try:
self.conn = self.ldap.initialize(self.url)
self.conn.simple_bind_s(self.user, self.password)
except self.ldap.SERVER_DOWN:
self.conn = None
raise
search_s = __wrap_reconnect(lambda conn: conn.search_s)
add_s = __wrap_reconnect(lambda conn: conn.add_s)
delete_s = __wrap_reconnect(lambda conn: conn.delete_s)
modify_s = __wrap_reconnect(lambda conn: conn.modify_s)
class LdapDriver(object): class LdapDriver(object):
"""Ldap Auth driver """Ldap Auth driver
@@ -124,8 +159,8 @@ class LdapDriver(object):
LdapDriver.project_objectclass = 'novaProject' LdapDriver.project_objectclass = 'novaProject'
self.__cache = None self.__cache = None
if LdapDriver.conn is None: if LdapDriver.conn is None:
LdapDriver.conn = self.ldap.initialize(FLAGS.ldap_url) LdapDriver.conn = LDAPWrapper(self.ldap, FLAGS.ldap_url,
LdapDriver.conn.simple_bind_s(FLAGS.ldap_user_dn, FLAGS.ldap_user_dn,
FLAGS.ldap_password) FLAGS.ldap_password)
if LdapDriver.mc is None: if LdapDriver.mc is None:
LdapDriver.mc = memcache.Client(FLAGS.memcached_servers, debug=0) LdapDriver.mc = memcache.Client(FLAGS.memcached_servers, debug=0)

View File

@@ -25,6 +25,7 @@ from nova import log as logging
from nova import test from nova import test
from nova.auth import manager from nova.auth import manager
from nova.api.ec2 import cloud from nova.api.ec2 import cloud
from nova.auth import fakeldap
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
LOG = logging.getLogger('nova.tests.auth_unittest') LOG = logging.getLogger('nova.tests.auth_unittest')
@@ -369,6 +370,15 @@ class _AuthManagerBaseTestCase(test.TestCase):
class AuthManagerLdapTestCase(_AuthManagerBaseTestCase): class AuthManagerLdapTestCase(_AuthManagerBaseTestCase):
auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
def test_reconnect_on_server_failure(self):
self.manager.get_users()
fakeldap.server_fail = True
try:
self.assertRaises(fakeldap.SERVER_DOWN, self.manager.get_users)
finally:
fakeldap.server_fail = False
self.manager.get_users()
class AuthManagerDbTestCase(_AuthManagerBaseTestCase): class AuthManagerDbTestCase(_AuthManagerBaseTestCase):
auth_driver = 'nova.auth.dbdriver.DbDriver' auth_driver = 'nova.auth.dbdriver.DbDriver'