Add a plugin_manager module

As the number of plugin types increases, this will make it easier to load
them all.

Change-Id: Ie9eb31c110ea313062e03cf59a2e3021058254f9
This commit is contained in:
Zane Bitter 2014-02-25 13:39:17 -05:00
parent 480087cfb4
commit 8a3eb2f9a5
2 changed files with 214 additions and 0 deletions

View File

@ -0,0 +1,115 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import itertools
import sys
from oslo.config import cfg
from heat.openstack.common import log
from heat.common import plugin_loader
logger = log.getLogger(__name__)
class PluginManager(object):
'''A class for managing plugin modules.'''
def __init__(self, *extra_packages):
'''Initialise the Heat Engine plugin package, and any others.
The heat.engine.plugins package is always created, if it does not
exist, from the plugin directories specified in the config file, and
searched for modules. In addition, any extra packages specified are
also searched for modules. e.g.
>>> PluginManager('heat.engine.resources')
will load all modules in the heat.engine.resources package as well as
any user-supplied plugin modules.
'''
def packages():
for package_name in extra_packages:
yield sys.modules[package_name]
cfg.CONF.import_opt('plugin_dirs', 'heat.common.config')
yield plugin_loader.create_subpackage(cfg.CONF.plugin_dirs,
'heat.engine')
def modules():
pkg_modules = itertools.imap(plugin_loader.load_modules,
packages())
return itertools.chain.from_iterable(pkg_modules)
self.modules = list(modules())
def map_to_modules(self, function):
'''Iterate over the results of calling a function on every module.'''
return itertools.imap(function, self.modules)
class PluginMapping(object):
'''A class for managing plugin mappings.'''
def __init__(self, names, *args, **kwargs):
'''Initialise with the mapping name(s) and arguments.
`names` can be a single name or a list of names. The first name found
in a given module is the one used. Each module is searched for a
function called <name>_mapping() which is called to retrieve the
mappings provided by that module. Any other arguments passed will be
passed to the mapping functions.
'''
if isinstance(names, basestring):
names = [names]
self.names = ['%s_mapping' % name for name in names]
self.args = args
self.kwargs = kwargs
def load_from_module(self, module):
'''Return the mapping specified in the given module.
If no such mapping is specified, an empty dictionary is returned.
'''
for mapping_name in self.names:
mapping_func = getattr(module, mapping_name, None)
if callable(mapping_func):
fmt_data = {'mapping_name': mapping_name, 'module': module}
try:
mapping_dict = mapping_func(*self.args, **self.kwargs)
except Exception:
logger.error(_('Failed to load %(mapping_name)s '
'from %(module)s') % fmt_data)
raise
else:
if isinstance(mapping_dict, collections.Mapping):
return mapping_dict
elif mapping_dict is not None:
logger.error(_('Invalid type for %(mapping_name)s '
'from %(module)s') % fmt_data)
return {}
def load_all(self, plugin_manager):
'''Iterate over the mappings from all modules in the plugin manager.
Mappings are returned as a list of (key, value) tuples.
'''
mod_dicts = plugin_manager.map_to_modules(self.load_from_module)
return itertools.chain.from_iterable(d.iteritems() for d in mod_dicts)

View File

@ -0,0 +1,99 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import sys
import types
from heat.tests.common import HeatTestCase
from heat.engine import plugin_manager
def legacy_test_mapping():
return {'foo': 'bar', 'baz': 'quux'}
def current_test_mapping():
return {'blarg': 'wibble', 'bar': 'baz'}
def args_test_mapping(*args):
return dict(enumerate(args))
def kwargs_test_mapping(**kwargs):
return kwargs
def error_test_mapping():
raise MappingTestError
class MappingTestError(Exception):
pass
class TestPluginManager(HeatTestCase):
@staticmethod
def module():
return sys.modules[__name__]
def test_load_single_mapping(self):
pm = plugin_manager.PluginMapping('current_test')
self.assertEqual(current_test_mapping(),
pm.load_from_module(self.module()))
def test_load_first_alternative_mapping(self):
pm = plugin_manager.PluginMapping(['current_test', 'legacy_test'])
self.assertEqual(current_test_mapping(),
pm.load_from_module(self.module()))
def test_load_second_alternative_mapping(self):
pm = plugin_manager.PluginMapping(['nonexist', 'current_test'])
self.assertEqual(current_test_mapping(),
pm.load_from_module(self.module()))
def test_load_mapping_args(self):
pm = plugin_manager.PluginMapping('args_test', 'baz', 'quux')
expected = {0: 'baz', 1: 'quux'}
self.assertEqual(expected, pm.load_from_module(self.module()))
def test_load_mapping_kwargs(self):
pm = plugin_manager.PluginMapping('kwargs_test', baz='quux')
self.assertEqual({'baz': 'quux'}, pm.load_from_module(self.module()))
def test_load_mapping_non_existent(self):
pm = plugin_manager.PluginMapping('nonexist')
self.assertEqual({}, pm.load_from_module(self.module()))
def test_load_mapping_error(self):
pm = plugin_manager.PluginMapping('error_test')
self.assertRaises(MappingTestError, pm.load_from_module, self.module())
def test_modules(self):
mgr = plugin_manager.PluginManager('heat.tests')
for module in mgr.modules:
self.assertEqual(types.ModuleType, type(module))
self.assertTrue(module.__name__.startswith('heat.tests') or
module.__name__.startswith('heat.engine.plugins'))
def test_load_all(self):
mgr = plugin_manager.PluginManager('heat.tests')
pm = plugin_manager.PluginMapping('current_test')
all_items = pm.load_all(mgr)
for item in current_test_mapping().iteritems():
self.assertIn(item, all_items)