Split Provider configuration into files

Change-Id: Ie88895ceda3a620aa33999a980e04e144d8f32e6
This commit is contained in:
Yuval Brik 2016-04-04 23:36:30 +03:00
parent e90c2356d1
commit 7b24765f4a
24 changed files with 314 additions and 202 deletions

View File

@ -0,0 +1,9 @@
[provider]
name = OS Infra Provider
description = This provider uses OpenStack's own services (swift, cinder) as storage
id = cf56bd3e-97a7-4078-b6d5-f36246333fd9
# TODO(yuvalbr)
# bank = swift
# plugin = cinder_backup
# plugin = glance_backup
# plugin = neutron_backup

View File

@ -80,11 +80,6 @@ global_opts = [
cfg.IntOpt('lease_validity_window', cfg.IntOpt('lease_validity_window',
default=100, default=100,
help='validity_window for bank lease, in seconds'), help='validity_window for bank lease, in seconds'),
cfg.ListOpt('enabled_providers',
default=None,
help='A list of provider names to use. These provider names '
'should be backed by a unique [CONFIG] group '
'with its options'),
] ]
CONF.register_opts(global_opts) CONF.register_opts(global_opts)

View File

@ -45,6 +45,9 @@ class LeasePlugin(object):
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class BankPlugin(object): class BankPlugin(object):
def __init__(self, config=None):
self._config = config
@abc.abstractmethod @abc.abstractmethod
def create_object(self, key, value): def create_object(self, key, value):
return return

View File

@ -60,8 +60,6 @@ swift_client_opts = [
'making SSL connection to Swift.'), 'making SSL connection to Swift.'),
] ]
CONF = cfg.CONF
CONF.register_opts(swift_client_opts, "swift_client")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -70,19 +68,24 @@ class SwiftConnectionFailed(exception.SmaugException):
class SwiftBankPlugin(BankPlugin, LeasePlugin): class SwiftBankPlugin(BankPlugin, LeasePlugin):
def __init__(self, context, object_container): def __init__(self, config, context, object_container):
super(BankPlugin, self).__init__() super(SwiftBankPlugin, self).__init__(config)
self.context = context self.context = context
self.swift_retry_attempts = CONF.swift_client.bank_swift_retry_attempts self._config.register_opts(swift_client_opts, "swift_client")
self.swift_retry_backoff = CONF.swift_client.bank_swift_retry_backoff self.swift_retry_attempts = \
self.swift_auth_insecure = CONF.swift_client.bank_swift_auth_insecure self._config.swift_client.bank_swift_retry_attempts
self.swift_ca_cert_file = CONF.swift_client.bank_swift_ca_cert_file self.swift_retry_backoff = \
self.lease_expire_window = CONF.lease_expire_window self._config.swift_client.bank_swift_retry_backoff
self.lease_renew_window = CONF.lease_renew_window self.swift_auth_insecure = \
self._config.swift_client.bank_swift_auth_insecure
self.swift_ca_cert_file = \
self._config.swift_client.bank_swift_ca_cert_file
self.lease_expire_window = self._config.lease_expire_window
self.lease_renew_window = self._config.lease_renew_window
# TODO(luobin): # TODO(luobin):
# init lease_validity_window # init lease_validity_window
# according to lease_renew_window if not configured # according to lease_renew_window if not configured
self.lease_validity_window = CONF.lease_validity_window self.lease_validity_window = self._config.lease_validity_window
# TODO(luobin): create a uuid of this bank_plugin # TODO(luobin): create a uuid of this bank_plugin
self.owner_id = str(uuid.uuid4()) self.owner_id = str(uuid.uuid4())
@ -113,20 +116,20 @@ class SwiftBankPlugin(BankPlugin, LeasePlugin):
initial_delay=self.lease_renew_window) initial_delay=self.lease_renew_window)
def _setup_connection(self): def _setup_connection(self):
if CONF.swift_client.bank_swift_auth == "single_user": if self._config.swift_client.bank_swift_auth == "single_user":
connection = swift.Connection( connection = swift.Connection(
authurl=CONF.swift_client.bank_swift_auth_url, authurl=self._config.swift_client.bank_swift_auth_url,
auth_version=CONF.swift_client.bank_swift_auth_version, auth_version=self._config.swift_client.bank_swift_auth_version,
tenant_name=CONF.swift_client.bank_swift_tenant_name, tenant_name=self._config.swift_client.bank_swift_tenant_name,
user=CONF.swift_client.bank_swift_user, user=self._config.swift_client.bank_swift_user,
key=CONF.swift_client.bank_swift_key, key=self._config.swift_client.bank_swift_key,
retries=self.swift_retry_attempts, retries=self.swift_retry_attempts,
starting_backoff=self.swift_retry_backoff, starting_backoff=self.swift_retry_backoff,
insecure=self.swift_auth_insecure, insecure=self.swift_auth_insecure,
cacert=self.swift_ca_cert_file) cacert=self.swift_ca_cert_file)
else: else:
connection = swift.Connection( connection = swift.Connection(
preauthurl=CONF.swift_client.bank_swift_url, preauthurl=self._config.swift_client.bank_swift_url,
preauthtoken=self.context.auth_token, preauthtoken=self.context.auth_token,
retries=self.swift_retry_attempts, retries=self.swift_retry_attempts,
starting_backoff=self.swift_retry_backoff, starting_backoff=self.swift_retry_backoff,

View File

@ -12,6 +12,7 @@
import os import os
from oslo_config import cfg
from oslo_log import log as logging from oslo_log import log as logging
from oslo_utils import importutils from oslo_utils import importutils
from smaug.i18n import _LE from smaug.i18n import _LE
@ -36,11 +37,11 @@ class ClientFactory(object):
yield '%s.clients.%s' % (__package__, name) yield '%s.clients.%s' % (__package__, name)
@classmethod @classmethod
def create_client(cls, service, context): def create_client(cls, service, context, conf=cfg.CONF):
if not cls._factory: if not cls._factory:
cls._factory = {} cls._factory = {}
for module in cls._list_clients(): for module in cls._list_clients():
module = importutils.import_module(module) module = importutils.import_module(module)
cls._factory[module.SERVICE] = module cls._factory[module.SERVICE] = module
return cls._factory[service].create(context) return cls._factory[service].create(context, conf)

View File

@ -35,9 +35,10 @@ cfg.CONF.register_opts(cinder_client_opts, group=SERVICE + '_client')
CINDERCLIENT_VERSION = '2' CINDERCLIENT_VERSION = '2'
def create(context): def create(context, conf):
conf.register_opts(cinder_client_opts, group=SERVICE + '_client')
try: try:
url = utils.get_url(SERVICE, context, append_project=True) url = utils.get_url(SERVICE, context, conf, append_project=True)
except Exception: except Exception:
LOG.error(_LE("Get cinder service endpoint url failed.")) LOG.error(_LE("Get cinder service endpoint url failed."))
raise raise

View File

@ -35,9 +35,10 @@ cfg.CONF.register_opts(glance_client_opts, group=SERVICE + '_client')
GLANCECLIENT_VERSION = '2' GLANCECLIENT_VERSION = '2'
def create(context): def create(context, conf):
conf.register_opts(glance_client_opts, group=SERVICE + '_client')
try: try:
url = utils.get_url(SERVICE, context) url = utils.get_url(SERVICE, context, conf)
except Exception: except Exception:
LOG.error(_LE("Get glance service endpoint url failed")) LOG.error(_LE("Get glance service endpoint url failed"))
raise raise

View File

@ -33,9 +33,10 @@ neutron_client_opts = [
cfg.CONF.register_opts(neutron_client_opts, group=SERVICE + '_client') cfg.CONF.register_opts(neutron_client_opts, group=SERVICE + '_client')
def create(context): def create(context, conf):
conf.register_opts(neutron_client_opts, group=SERVICE + '_client')
try: try:
url = utils.get_url(SERVICE, context) url = utils.get_url(SERVICE, context, conf)
except Exception: except Exception:
LOG.error(_LE("Get neutron service endpoint url failed")) LOG.error(_LE("Get neutron service endpoint url failed"))
raise raise

View File

@ -36,9 +36,10 @@ cfg.CONF.register_opts(nova_client_opts, group=SERVICE + '_client')
NOVACLIENT_VERSION = '2' NOVACLIENT_VERSION = '2'
def create(context): def create(context, conf):
conf.register_opts(nova_client_opts, group=SERVICE + '_client')
try: try:
url = utils.get_url(SERVICE, context, append_project=True) url = utils.get_url(SERVICE, context, conf, append_project=True)
except Exception: except Exception:
LOG.error(_LE("Get nova service endpoint url failed.")) LOG.error(_LE("Get nova service endpoint url failed."))
raise raise

View File

@ -24,6 +24,8 @@ LOG = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class ProtectionPlugin(object): class ProtectionPlugin(object):
def __init__(self, config=None):
self._config = config
@abc.abstractmethod @abc.abstractmethod
def get_supported_resources_types(self): def get_supported_resources_types(self):

View File

@ -10,22 +10,28 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import os
from oslo_config import cfg from oslo_config import cfg
from oslo_log import log as logging from oslo_log import log as logging
from smaug.common import constants
from smaug.i18n import _LE from smaug.i18n import _LE
from smaug.services.protection import checkpoint from smaug.services.protection import checkpoint
from smaug import utils from smaug import utils
provider_opt = [ provider_opts = [
cfg.MultiStrOpt('plugin', cfg.MultiStrOpt('plugin',
default='', default='',
help='plugins to use for protection'), help='plugins to use for protection'),
cfg.StrOpt('bank',
default='',
help='bank plugin to use for storage'),
cfg.StrOpt('description', cfg.StrOpt('description',
default='', default='',
help='the description of provider'), help='the description of provider'),
cfg.StrOpt('provider_id', cfg.StrOpt('name',
default='',
help='the name of provider'),
cfg.StrOpt('id',
default='', default='',
help='the provider id') help='the provider id')
] ]
@ -35,13 +41,20 @@ LOG = logging.getLogger(__name__)
PROTECTION_NAMESPACE = 'smaug.protections' PROTECTION_NAMESPACE = 'smaug.protections'
CONF.register_opt(cfg.StrOpt('provider_config_dir',
default='providers.d',
help='Configuration directory for providers.'
' Absolute path, or relative to smaug '
' configuration directory.'))
class PluggableProtectionProvider(object): class PluggableProtectionProvider(object):
def __init__(self, provider_id, provider_name, description, plugins): def __init__(self, provider_config):
super(PluggableProtectionProvider, self).__init__() super(PluggableProtectionProvider, self).__init__()
self._id = provider_id self._config = provider_config
self._name = provider_name self._id = self._config.provider.id
self._description = description self._name = self._config.provider.name
self._description = self._config.provider.description
self._extended_info_schema = {'options_schema': {}, self._extended_info_schema = {'options_schema': {},
'restore_schema': {}, 'restore_schema': {},
'saved_info_schema': {}} 'saved_info_schema': {}}
@ -49,12 +62,21 @@ class PluggableProtectionProvider(object):
self._bank_plugin = None self._bank_plugin = None
self._plugin_map = {} self._plugin_map = {}
self._load_plugins(plugins=plugins) if hasattr(self._config.provider, 'bank') \
and not self._config.provider.bank:
raise ImportError("Empty bank")
self._load_bank(self._config.provider.bank)
if hasattr(self._config.provider, 'plugin'):
for plugin_name in self._config.provider.plugin:
if not plugin_name:
raise ImportError("Empty protection plugin")
self._load_plugin(plugin_name)
if self._bank_plugin: if self._bank_plugin:
self.checkpoint_collection = checkpoint.CheckpointCollection( self.checkpoint_collection = checkpoint.CheckpointCollection(
self._bank_plugin) self._bank_plugin)
else: else:
LOG.error(_LE('Bank plugin not exist,check your configuration')) LOG.error(_LE('Bank plugin not exist, check your configuration'))
@property @property
def id(self): def id(self):
@ -72,27 +94,43 @@ class PluggableProtectionProvider(object):
def extended_info_schema(self): def extended_info_schema(self):
return self._extended_info_schema return self._extended_info_schema
def _load_plugins(self, plugins): @property
for plugin_name in plugins: def bank(self):
return self._bank_plugin
@property
def plugins(self):
return self._plugin_map
def _load_bank(self, bank_name):
try: try:
plugin = utils.load_plugin(PROTECTION_NAMESPACE, plugin_name) plugin = utils.load_plugin(PROTECTION_NAMESPACE, bank_name,
self._config)
except Exception: except Exception:
LOG.exception(_LE("Load protection plugin: %s failed."), LOG.error(_LE("Load bank plugin: '%s' failed."), bank_name)
plugin_name) raise
else:
self._bank_plugin = plugin
def _load_plugin(self, plugin_name):
try:
plugin = utils.load_plugin(PROTECTION_NAMESPACE, plugin_name,
self._config)
except Exception:
LOG.error(_LE("Load protection plugin: '%s' failed."), plugin_name)
raise raise
else: else:
self._plugin_map[plugin_name] = plugin self._plugin_map[plugin_name] = plugin
if constants.PLUGIN_BANK in plugin_name.lower(): for resource in plugin.get_supported_resources_types():
self._bank_plugin = plugin
if hasattr(plugin, 'get_options_schema'): if hasattr(plugin, 'get_options_schema'):
self._extended_info_schema['options_schema'][plugin_name] \ self._extended_info_schema['options_schema'][resource] \
= plugin.get_options_schema() = plugin.get_options_schema(resource)
if hasattr(plugin, 'get_restore_schema'): if hasattr(plugin, 'get_restore_schema'):
self._extended_info_schema['restore_schema'][plugin_name] \ self._extended_info_schema['restore_schema'][resource] \
= plugin.get_restore_schema() = plugin.get_restore_schema(resource)
if hasattr(plugin, 'get_saved_info_schema'): if hasattr(plugin, 'get_saved_info_schema'):
self._extended_info_schema['saved_info_schema'][plugin_name] \ self._extended_info_schema['saved_info_schema'][resource] \
= plugin.get_saved_info_schema() = plugin.get_saved_info_schema(resource)
def get_checkpoint_collection(self): def get_checkpoint_collection(self):
return self.checkpoint_collection return self.checkpoint_collection
@ -109,51 +147,29 @@ class ProviderRegistry(object):
self._load_providers() self._load_providers()
def _load_providers(self): def _load_providers(self):
"""load provider """load provider"""
config_dir = utils.find_config(CONF.provider_config_dir)
smaug.conf example: for config_file in os.listdir(config_dir):
[default] if not config_file.endswith('.conf'):
enabled_providers=provider1,provider2
[provider1]
provider_id='' configured by admin
plugin=BANK define in setup.cfg
plugin=VolumeProtectionPlugin define in setup.cfg
description='the description of provider1'
[provider2]
provider_id='' configured by admin
plugin=BANK define in setup.cfg
plugin=VolumeProtectionPlugin define in setup.cfg
plugin=ServerProtectionPlugin define in setup.cfg
description='the description of provider2'
"""
if CONF.enabled_providers:
for provider_name in CONF.enabled_providers:
CONF.register_opts(provider_opt, group=provider_name)
plugins = getattr(CONF, provider_name).plugin
description = getattr(CONF, provider_name).description
provider_id = getattr(CONF, provider_name).provider_id
if not all([plugins, provider_id]):
LOG.error(_LE("Invalid provider:%s,check provider"
" configuration"),
provider_name)
continue continue
config_path = os.path.abspath(os.path.join(config_dir,
config_file))
provider_config = cfg.ConfigOpts()
provider_config(args=['--config-file=' + config_path])
provider_config.register_opts(provider_opts, 'provider')
try: try:
provider = PluggableProtectionProvider(provider_id, provider = PluggableProtectionProvider(provider_config)
provider_name,
description,
plugins)
except Exception: except Exception:
LOG.exception(_LE("Load provider: %s failed."), LOG.error(_LE("Load provider: %s failed."),
provider_name) provider_config.provider.name)
else: else:
self.providers[provider_id] = provider self.providers[provider.id] = provider
def list_providers(self, list_option=None): def list_providers(self):
if not list_option:
return [dict(id=provider.id, name=provider.name, return [dict(id=provider.id, name=provider.name,
description=provider.description) description=provider.description)
for provider in self.providers.values()] for provider in self.providers.values()]
# It seems that we don't need list_option
def show_provider(self, provider_id): def show_provider(self, provider_id):
return self.providers.get(provider_id, None) return self.providers.get(provider_id, None)

View File

@ -10,7 +10,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from oslo_config import cfg
from smaug import exception from smaug import exception
from smaug.i18n import _ from smaug.i18n import _
@ -39,9 +38,9 @@ def _parse_service_endpoint(endpoint_url, context, append_project=False):
else endpoint_url else endpoint_url
def get_url(service, context, append_project=False): def get_url(service, context, conf, append_project=False):
'''Return the url of given service endpoint.''' '''Return the url of given service endpoint.'''
client_conf = getattr(cfg.CONF, service + '_client') client_conf = getattr(conf, service + '_client')
endpoint = getattr(client_conf, service + '_endpoint') endpoint = getattr(client_conf, service + '_endpoint')
if endpoint is not None: if endpoint is not None:

View File

@ -35,13 +35,13 @@ class CinderClientTest(base.TestCase):
cfg.CONF.set_default('cinder_endpoint', cfg.CONF.set_default('cinder_endpoint',
'http://127.0.0.1:8776/v2', 'http://127.0.0.1:8776/v2',
'cinder_client') 'cinder_client')
client = cinder.create(self._context) client = cinder.create(self._context, cfg.CONF)
self.assertEqual('volumev2', client.client.service_type) self.assertEqual('volumev2', client.client.service_type)
self.assertEqual('http://127.0.0.1:8776/v2/abcd', self.assertEqual('http://127.0.0.1:8776/v2/abcd',
client.client.management_url) client.client.management_url)
def test_create_client_by_catalog(self): def test_create_client_by_catalog(self):
client = cinder.create(self._context) client = cinder.create(self._context, cfg.CONF)
self.assertEqual('volumev2', client.client.service_type) self.assertEqual('volumev2', client.client.service_type)
self.assertEqual('http://127.0.0.1:8776/v2/abcd', self.assertEqual('http://127.0.0.1:8776/v2/abcd',
client.client.management_url) client.client.management_url)

View File

@ -40,9 +40,9 @@ class GlanceClientTest(base.TestCase):
cfg.CONF.set_default('glance_endpoint', cfg.CONF.set_default('glance_endpoint',
'http://127.0.0.1:9292', 'http://127.0.0.1:9292',
'glance_client') 'glance_client')
gc = glance.create(self._context) gc = glance.create(self._context, cfg.CONF)
self.assertEqual('http://127.0.0.1:9292', gc.http_client.endpoint) self.assertEqual('http://127.0.0.1:9292', gc.http_client.endpoint)
def test_create_client_by_catalog(self): def test_create_client_by_catalog(self):
gc = glance.create(self._context) gc = glance.create(self._context, cfg.CONF)
self.assertEqual('http://127.0.0.1:9292', gc.http_client.endpoint) self.assertEqual('http://127.0.0.1:9292', gc.http_client.endpoint)

View File

@ -40,9 +40,9 @@ class NeutronClientTest(base.TestCase):
cfg.CONF.set_default('neutron_endpoint', cfg.CONF.set_default('neutron_endpoint',
'http://127.0.0.1:9696', 'http://127.0.0.1:9696',
'neutron_client') 'neutron_client')
nc = neutron.create(self._context) nc = neutron.create(self._context, cfg.CONF)
self.assertEqual('http://127.0.0.1:9696', nc.httpclient.endpoint_url) self.assertEqual('http://127.0.0.1:9696', nc.httpclient.endpoint_url)
def test_create_client_by_catalog(self): def test_create_client_by_catalog(self):
nc = neutron.create(self._context) nc = neutron.create(self._context, cfg.CONF)
self.assertEqual('http://127.0.0.1:9696', nc.httpclient.endpoint_url) self.assertEqual('http://127.0.0.1:9696', nc.httpclient.endpoint_url)

View File

@ -35,13 +35,13 @@ class NovaClientTest(base.TestCase):
cfg.CONF.set_default('nova_endpoint', cfg.CONF.set_default('nova_endpoint',
'http://127.0.0.1:8774/v2.1', 'http://127.0.0.1:8774/v2.1',
'nova_client') 'nova_client')
client = nova.create(self._context) client = nova.create(self._context, cfg.CONF)
self.assertEqual('compute', client.client.service_type) self.assertEqual('compute', client.client.service_type)
self.assertEqual('http://127.0.0.1:8774/v2.1/abcd', self.assertEqual('http://127.0.0.1:8774/v2.1/abcd',
client.client.management_url) client.client.management_url)
def test_create_client_by_catalog(self): def test_create_client_by_catalog(self):
client = nova.create(self._context) client = nova.create(self._context, cfg.CONF)
self.assertEqual('compute', client.client.service_type) self.assertEqual('compute', client.client.service_type)
self.assertEqual('http://127.0.0.1:8774/v2.1/abcd', self.assertEqual('http://127.0.0.1:8774/v2.1/abcd',
client.client.management_url) client.client.management_url)

View File

@ -28,3 +28,5 @@ def set_defaults(conf):
conf.set_default('auth_strategy', 'noauth') conf.set_default('auth_strategy', 'noauth')
conf.set_default('state_path', os.path.abspath( conf.set_default('state_path', os.path.abspath(
os.path.join(os.path.dirname(__file__), '..', '..', '..'))) os.path.join(os.path.dirname(__file__), '..', '..', '..')))
conf.set_default('provider_config_dir',
os.path.join(os.path.dirname(__file__), 'fake_providers'))

View File

@ -0,0 +1,40 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_config import cfg
from smaug.services.protection import bank_plugin
fake_bank_opts = [
cfg.StrOpt('fake_host'),
]
class FakeBankPlugin(bank_plugin.BankPlugin):
def __init__(self, config=None):
super(FakeBankPlugin, self).__init__(config)
config.register_opts(fake_bank_opts, 'fake_bank')
def create_object(self, key, value):
return
def update_object(self, key, value):
return
def get_object(self, key):
return
def list_objects(self, prefix=None, limit=None, marker=None):
return
def delete_object(self, key):
return

View File

@ -0,0 +1,49 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_config import cfg
from smaug.services.protection import protection_plugin
fake_plugin_opts = [
cfg.StrOpt('fake_user'),
]
class FakeProtectionPlugin(protection_plugin.ProtectionPlugin):
def __init__(self, config=None):
super(FakeProtectionPlugin, self).__init__(config)
config.register_opts(fake_plugin_opts, 'fake_plugin')
def get_supported_resources_types(self):
return ['Test::Resource']
def get_options_schema(self, resource_type):
return []
def get_saved_info_schema(self, resource_type):
return []
def get_restore_schema(self, resource_type):
return []
def get_saved_info(self, metadata_store, resource):
pass
def get_protection_stats(self, protection_id):
pass
def on_resource_start(self, context):
pass
def on_resource_end(self, context):
pass

View File

@ -0,0 +1,12 @@
[provider]
name = fake_provider1
id = fake_id1
description = Test Provider 1
bank = smaug.tests.unit.fake_bank.FakeBankPlugin
plugin = smaug.tests.unit.fake_protection.FakeProtectionPlugin
[fake_plugin]
fake_user = user
[fake_bank]
fake_host = thor

View File

@ -0,0 +1,4 @@
[provider]
name = fake_provider2
id = fake_id2
description = Test Provider 2

View File

@ -11,75 +11,48 @@
# under the License. # under the License.
import mock import mock
from oslo_config import cfg from oslo_config import cfg
from smaug.services.protection import provider from smaug.services.protection import provider
from smaug.tests import base from smaug.tests import base
provider_opt = [
cfg.MultiStrOpt('plugin',
default='',
help='plugins to use for protection'),
cfg.StrOpt('description',
default='',
help='the description of provider'),
cfg.StrOpt('provider_id',
default='',
help='the provider id')
]
CONF = cfg.CONF CONF = cfg.CONF
class ProviderRegistryTest(base.TestCase): class ProviderRegistryTest(base.TestCase):
def setUp(self): def setUp(self):
super(ProviderRegistryTest, self).setUp() super(ProviderRegistryTest, self).setUp()
CONF.set_override('enabled_providers',
['provider1', 'provider2'])
CONF.register_opts(provider_opt, group='provider1')
CONF.register_opts(provider_opt, group='provider2')
CONF.set_override('plugin', ['SERVER', 'VOLUME'],
group='provider1')
CONF.set_override('plugin', ['SERVER'],
group='provider2')
CONF.set_override('description', 'FAKE1', group='provider1')
CONF.set_override('description', 'FAKE2', group='provider2')
CONF.set_override('provider_id', 'id1', group='provider1')
CONF.set_override('provider_id', 'id2', group='provider2')
@mock.patch.object(provider.PluggableProtectionProvider, '_load_plugins') @mock.patch.object(provider.PluggableProtectionProvider, '_load_bank')
def test_load_providers(self, mock_load_plugins): @mock.patch.object(provider.PluggableProtectionProvider, '_load_plugin')
CONF.set_override('plugin', ['SERVER'], def test_load_providers(self, mock_load_bank, mock_load_plugin):
group='provider2')
pr = provider.ProviderRegistry() pr = provider.ProviderRegistry()
self.assertTrue(mock_load_plugins.called) self.assertEqual(mock_load_plugin.call_count, 1)
self.assertEqual(len(pr.providers), 2) self.assertEqual(mock_load_bank.call_count, 1)
@mock.patch.object(provider.PluggableProtectionProvider, '_load_plugins')
def test_load_providers_with_no_plugins(self, mock_load_plugins):
CONF.set_override('plugin', None,
group='provider2')
pr = provider.ProviderRegistry()
self.assertEqual(mock_load_plugins.call_count, 1)
self.assertEqual(len(pr.providers), 1) self.assertEqual(len(pr.providers), 1)
@mock.patch.object(provider.PluggableProtectionProvider, '_load_plugins') self.assertEqual(pr.providers['fake_id1'].name, 'fake_provider1')
def test_list_provider(self, mock_load_plugins): self.assertNotIn('fake_provider2', pr.providers)
CONF.set_override('plugin', ['SERVER'],
group='provider2')
pr = provider.ProviderRegistry()
self.assertEqual(2, len(pr.list_providers()))
@mock.patch.object(provider.PluggableProtectionProvider, '_load_plugins') def test_provider_bank_config(self):
def test_show_provider(self, mock_load_plugins): pr = provider.ProviderRegistry()
CONF.set_override('plugin', ['SERVER'], provider1 = pr.show_provider('fake_id1')
group='provider2') self.assertEqual(provider1.bank._config.fake_bank.fake_host, 'thor')
def test_provider_plugin_config(self):
pr = provider.ProviderRegistry()
provider1 = pr.show_provider('fake_id1')
plugin_name = 'smaug.tests.unit.fake_protection.FakeProtectionPlugin'
self.assertEqual(
provider1.plugins[plugin_name]._config.fake_plugin.fake_user,
'user')
def test_list_provider(self):
pr = provider.ProviderRegistry()
self.assertEqual(1, len(pr.list_providers()))
def test_show_provider(self):
pr = provider.ProviderRegistry() pr = provider.ProviderRegistry()
provider_list = pr.list_providers() provider_list = pr.list_providers()
for provider_node in provider_list: for provider_node in provider_list:
self.assertTrue(pr.show_provider(provider_node['id'])) self.assertTrue(pr.show_provider(provider_node['id']))
def tearDown(self):
CONF.register_opts(provider_opt, group='provider1')
CONF.register_opts(provider_opt, group='provider2')
CONF.set_override('enabled_providers',
None)
super(ProviderRegistryTest, self).tearDown()

View File

@ -41,7 +41,7 @@ class SwiftBankPluginTest(base.TestCase):
import_str=import_str) import_str=import_str)
swift.Connection = mock.MagicMock() swift.Connection = mock.MagicMock()
swift.Connection.return_value = self.fake_connection swift.Connection.return_value = self.fake_connection
self.swift_bank_plugin = swift_bank_plugin_cls(None, self.swift_bank_plugin = swift_bank_plugin_cls(CONF, None,
self.object_container) self.object_container)
def test_acquire_lease(self): def test_acquire_lease(self):

View File

@ -125,7 +125,7 @@ def get_bool_param(param_string, params):
return strutils.bool_from_string(param, strict=True) return strutils.bool_from_string(param, strict=True)
def load_plugin(namespace, plugin_name): def load_plugin(namespace, plugin_name, *args, **kwargs):
try: try:
# Try to resolve plugin by name # Try to resolve plugin by name
mgr = driver.DriverManager(namespace, plugin_name) mgr = driver.DriverManager(namespace, plugin_name)
@ -138,4 +138,4 @@ def load_plugin(namespace, plugin_name):
LOG.exception(_LE("Error loading plugin by name, %s"), e1) LOG.exception(_LE("Error loading plugin by name, %s"), e1)
LOG.exception(_LE("Error loading plugin by class, %s"), e2) LOG.exception(_LE("Error loading plugin by class, %s"), e2)
raise ImportError(_("Class not found.")) raise ImportError(_("Class not found."))
return plugin_class() return plugin_class(*args, **kwargs)