diff --git a/heat/engine/plugin_manager.py b/heat/engine/plugin_manager.py new file mode 100644 index 0000000000..de8aee6bc8 --- /dev/null +++ b/heat/engine/plugin_manager.py @@ -0,0 +1,115 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import itertools +import sys + +from oslo.config import cfg + +from heat.openstack.common import log +from heat.common import plugin_loader + + +logger = log.getLogger(__name__) + + +class PluginManager(object): + '''A class for managing plugin modules.''' + + def __init__(self, *extra_packages): + '''Initialise the Heat Engine plugin package, and any others. + + The heat.engine.plugins package is always created, if it does not + exist, from the plugin directories specified in the config file, and + searched for modules. In addition, any extra packages specified are + also searched for modules. e.g. + + >>> PluginManager('heat.engine.resources') + + will load all modules in the heat.engine.resources package as well as + any user-supplied plugin modules. + + ''' + def packages(): + for package_name in extra_packages: + yield sys.modules[package_name] + + cfg.CONF.import_opt('plugin_dirs', 'heat.common.config') + yield plugin_loader.create_subpackage(cfg.CONF.plugin_dirs, + 'heat.engine') + + def modules(): + pkg_modules = itertools.imap(plugin_loader.load_modules, + packages()) + return itertools.chain.from_iterable(pkg_modules) + + self.modules = list(modules()) + + def map_to_modules(self, function): + '''Iterate over the results of calling a function on every module.''' + return itertools.imap(function, self.modules) + + +class PluginMapping(object): + '''A class for managing plugin mappings.''' + + def __init__(self, names, *args, **kwargs): + '''Initialise with the mapping name(s) and arguments. + + `names` can be a single name or a list of names. The first name found + in a given module is the one used. Each module is searched for a + function called _mapping() which is called to retrieve the + mappings provided by that module. Any other arguments passed will be + passed to the mapping functions. + + ''' + if isinstance(names, basestring): + names = [names] + + self.names = ['%s_mapping' % name for name in names] + self.args = args + self.kwargs = kwargs + + def load_from_module(self, module): + '''Return the mapping specified in the given module. + + If no such mapping is specified, an empty dictionary is returned. + ''' + for mapping_name in self.names: + mapping_func = getattr(module, mapping_name, None) + if callable(mapping_func): + fmt_data = {'mapping_name': mapping_name, 'module': module} + try: + mapping_dict = mapping_func(*self.args, **self.kwargs) + except Exception: + logger.error(_('Failed to load %(mapping_name)s ' + 'from %(module)s') % fmt_data) + raise + else: + if isinstance(mapping_dict, collections.Mapping): + return mapping_dict + elif mapping_dict is not None: + logger.error(_('Invalid type for %(mapping_name)s ' + 'from %(module)s') % fmt_data) + + return {} + + def load_all(self, plugin_manager): + '''Iterate over the mappings from all modules in the plugin manager. + + Mappings are returned as a list of (key, value) tuples. + ''' + mod_dicts = plugin_manager.map_to_modules(self.load_from_module) + return itertools.chain.from_iterable(d.iteritems() for d in mod_dicts) diff --git a/heat/tests/test_plugin_manager.py b/heat/tests/test_plugin_manager.py new file mode 100644 index 0000000000..e9011b4231 --- /dev/null +++ b/heat/tests/test_plugin_manager.py @@ -0,0 +1,99 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sys +import types + +from heat.tests.common import HeatTestCase + +from heat.engine import plugin_manager + + +def legacy_test_mapping(): + return {'foo': 'bar', 'baz': 'quux'} + + +def current_test_mapping(): + return {'blarg': 'wibble', 'bar': 'baz'} + + +def args_test_mapping(*args): + return dict(enumerate(args)) + + +def kwargs_test_mapping(**kwargs): + return kwargs + + +def error_test_mapping(): + raise MappingTestError + + +class MappingTestError(Exception): + pass + + +class TestPluginManager(HeatTestCase): + + @staticmethod + def module(): + return sys.modules[__name__] + + def test_load_single_mapping(self): + pm = plugin_manager.PluginMapping('current_test') + self.assertEqual(current_test_mapping(), + pm.load_from_module(self.module())) + + def test_load_first_alternative_mapping(self): + pm = plugin_manager.PluginMapping(['current_test', 'legacy_test']) + self.assertEqual(current_test_mapping(), + pm.load_from_module(self.module())) + + def test_load_second_alternative_mapping(self): + pm = plugin_manager.PluginMapping(['nonexist', 'current_test']) + self.assertEqual(current_test_mapping(), + pm.load_from_module(self.module())) + + def test_load_mapping_args(self): + pm = plugin_manager.PluginMapping('args_test', 'baz', 'quux') + expected = {0: 'baz', 1: 'quux'} + self.assertEqual(expected, pm.load_from_module(self.module())) + + def test_load_mapping_kwargs(self): + pm = plugin_manager.PluginMapping('kwargs_test', baz='quux') + self.assertEqual({'baz': 'quux'}, pm.load_from_module(self.module())) + + def test_load_mapping_non_existent(self): + pm = plugin_manager.PluginMapping('nonexist') + self.assertEqual({}, pm.load_from_module(self.module())) + + def test_load_mapping_error(self): + pm = plugin_manager.PluginMapping('error_test') + self.assertRaises(MappingTestError, pm.load_from_module, self.module()) + + def test_modules(self): + mgr = plugin_manager.PluginManager('heat.tests') + + for module in mgr.modules: + self.assertEqual(types.ModuleType, type(module)) + self.assertTrue(module.__name__.startswith('heat.tests') or + module.__name__.startswith('heat.engine.plugins')) + + def test_load_all(self): + mgr = plugin_manager.PluginManager('heat.tests') + pm = plugin_manager.PluginMapping('current_test') + + all_items = pm.load_all(mgr) + for item in current_test_mapping().iteritems(): + self.assertIn(item, all_items)