diff --git a/quantum/common/extensions.py b/quantum/common/extensions.py index 44d52a8b8..7d4672699 100644 --- a/quantum/common/extensions.py +++ b/quantum/common/extensions.py @@ -352,7 +352,8 @@ class ExtensionManager(object): LOG.debug(_('Ext description: %s'), extension.get_description()) LOG.debug(_('Ext namespace: %s'), extension.get_namespace()) LOG.debug(_('Ext updated: %s'), extension.get_updated()) - return self._plugin_supports(extension) + return (self._plugin_supports(extension) and + self._plugin_implements_interface(extension)) except AttributeError as ex: LOG.exception(_("Exception loading extension: %s"), unicode(ex)) return False @@ -361,6 +362,19 @@ class ExtensionManager(object): return (hasattr(self.plugin, "supports_extension") and self.plugin.supports_extension(extension)) + def _plugin_implements_interface(self, extension): + if not hasattr(extension, "get_plugin_interface"): + return True + interface = extension.get_plugin_interface() + expected_methods = self._get_public_methods(interface) + implemented_methods = self._get_public_methods(self.plugin.__class__) + missing_methods = set(expected_methods) - set(implemented_methods) + return len(missing_methods) == 0 + + def _get_public_methods(self, klass): + return filter(lambda name: not(name.startswith("_")), + klass.__dict__.keys()) + def _load_all_extensions(self): """Load extensions from the configured path. diff --git a/tests/unit/test_extensions.py b/tests/unit/test_extensions.py index 0fe557f0d..d806469a0 100644 --- a/tests/unit/test_extensions.py +++ b/tests/unit/test_extensions.py @@ -101,34 +101,48 @@ class StubPlugin(object): return extension.get_alias() in self.supported_extensions +class ExtensionExpectingPluginInterface(StubExtension): + + def get_plugin_interface(self): + return PluginInterface + + +class PluginInterface(object): + + def get_foo(self, bar=None): + pass + + class ExtensionManagerTest(unittest.TestCase): + def setUp(self): + self.ext_mgr = setup_extensions_middleware().ext_mgr + super(ExtensionManagerTest, self).setUp() + def test_invalid_extensions_are_not_registered(self): class InvalidExtension(object): def get_alias(self): return "invalid_extension" - ext_mgr = setup_extensions_middleware().ext_mgr - ext_mgr.add_extension(InvalidExtension()) - ext_mgr.add_extension(StubExtension("valid_extension")) + self.ext_mgr.add_extension(InvalidExtension()) + self.ext_mgr.add_extension(StubExtension("valid_extension")) - self.assertTrue('valid_extension' in ext_mgr.extensions) - self.assertFalse('invalid_extension' in ext_mgr.extensions) + self.assertTrue('valid_extension' in self.ext_mgr.extensions) + self.assertFalse('invalid_extension' in self.ext_mgr.extensions) def test_unsupported_extensions_are_not_loaded(self): - ext_mgr = setup_extensions_middleware().ext_mgr - ext_mgr.plugin = StubPlugin(supported_extensions=["e1", "e3"]) + self.ext_mgr.plugin = StubPlugin(supported_extensions=["e1", "e3"]) - ext_mgr.add_extension(StubExtension("e1")) - ext_mgr.add_extension(StubExtension("e2")) - ext_mgr.add_extension(StubExtension("e3")) + self.ext_mgr.add_extension(StubExtension("e1")) + self.ext_mgr.add_extension(StubExtension("e2")) + self.ext_mgr.add_extension(StubExtension("e3")) - self.assertTrue("e1" in ext_mgr.extensions) - self.assertFalse("e2" in ext_mgr.extensions) - self.assertTrue("e3" in ext_mgr.extensions) + self.assertTrue("e1" in self.ext_mgr.extensions) + self.assertFalse("e2" in self.ext_mgr.extensions) + self.assertTrue("e3" in self.ext_mgr.extensions) - def test_extensions_are_not_loaded_for_extensions_unaware_plugins(self): + def test_extensions_are_not_loaded_for_plugins_unaware_of_extensions(self): class ExtensionUnawarePlugin(object): """ This plugin does not implement supports_extension method. @@ -136,12 +150,54 @@ class ExtensionManagerTest(unittest.TestCase): """ pass - ext_mgr = setup_extensions_middleware().ext_mgr - ext_mgr.plugin = ExtensionUnawarePlugin() + self.ext_mgr.plugin = ExtensionUnawarePlugin() + self.ext_mgr.add_extension(StubExtension("e1")) - ext_mgr.add_extension(StubExtension("e1")) + self.assertFalse("e1" in self.ext_mgr.extensions) - self.assertFalse("e1" in ext_mgr.extensions) + def test_extensions_not_loaded_for_plugin_without_expected_interface(self): + + class PluginWithoutExpectedInterface(object): + """ + Plugin does not implement get_foo method as expected by extension + """ + def supports_extension(self, true): + return true + + self.ext_mgr.plugin = PluginWithoutExpectedInterface() + self.ext_mgr.add_extension(ExtensionExpectingPluginInterface("e1")) + + self.assertFalse("e1" in self.ext_mgr.extensions) + + def test_extensions_are_loaded_for_plugin_with_expected_interface(self): + + class PluginWithExpectedInterface(object): + """ + This Plugin implements get_foo method as expected by extension + """ + def supports_extension(self, true): + return true + + def get_foo(self, bar=None): + pass + + self.ext_mgr.plugin = PluginWithExpectedInterface() + self.ext_mgr.add_extension(ExtensionExpectingPluginInterface("e1")) + + self.assertTrue("e1" in self.ext_mgr.extensions) + + def test_extensions_expecting_quantum_plugin_interface_are_loaded(self): + class ExtensionForQuamtumPluginInterface(StubExtension): + """ + This Extension does not implement get_plugin_interface method. + This will work with any plugin implementing QuantumPluginBase + """ + pass + + self.ext_mgr.plugin = StubPlugin(supported_extensions=["e1"]) + self.ext_mgr.add_extension(ExtensionForQuamtumPluginInterface("e1")) + + self.assertTrue("e1" in self.ext_mgr.extensions) class ActionExtensionTest(unittest.TestCase):