Have Drivers create Providers

Use the new Driver class to create instances of Providers

Change-Id: Idfbde8d773a971133b49fbc318385893be293fac
This commit is contained in:
James E. Blair 2018-05-31 15:09:15 -07:00 committed by David Shrewsbury
parent 82d8c51250
commit e20858755f
10 changed files with 42 additions and 42 deletions

View File

@ -204,7 +204,7 @@ def get_provider_config(provider):
# Ensure legacy configuration still works when using fake cloud
if provider.get('name', '').startswith('fake'):
provider['driver'] = 'fake'
driver = Drivers._get(provider['driver'])
driver = Drivers.get(provider['driver'])
return driver.getProviderConfig(provider)
@ -234,7 +234,7 @@ def loadConfig(config_path):
config = openConfig(config_path)
# Call driver config reset now to clean global hooks like os_client_config
for driver in Drivers._drivers.values():
for driver in Drivers.drivers.values():
driver.reset()
newconfig = Config()

View File

@ -32,7 +32,6 @@ class Drivers:
log = logging.getLogger("nodepool.driver.Drivers")
drivers = {}
_drivers = {} # TODO: replace drivers
drivers_paths = None
@staticmethod
@ -72,22 +71,6 @@ class Drivers:
if not os.path.isdir(driver_path) or \
"__init__.py" not in os.listdir(driver_path):
continue
Drivers.log.debug("%s: loading driver", driver_path)
driver_obj = {}
for name, parent_class in (
("provider", Provider),
):
driver_obj[name] = Drivers._load_class(
driver, os.path.join(driver_path, "%s.py" % name),
parent_class)
if not driver_obj[name]:
break
if not driver_obj[name]:
Drivers.log.error(
"%s: skipping incorrect driver from %s.py",
driver_path, name)
continue
Drivers.drivers[driver] = driver_obj
driver_obj = Drivers._load_class(
driver, os.path.join(driver_path, "__init__.py"),
Driver)
@ -96,7 +79,7 @@ class Drivers:
"%s: skipping incorrect driver from __init__.py",
driver_path)
continue
Drivers._drivers[driver] = driver_obj()
Drivers.drivers[driver] = driver_obj()
Drivers.drivers_paths = drivers_paths
@ -109,16 +92,6 @@ class Drivers:
except KeyError:
raise RuntimeError("%s: unknown driver" % name)
# TODO: replace get
@staticmethod
def _get(name):
if not Drivers._drivers:
Drivers.load()
try:
return Drivers._drivers[name]
except KeyError:
raise RuntimeError("%s: unknown driver" % name)
class Driver(object, metaclass=abc.ABCMeta):
"""The Driver interface
@ -145,6 +118,18 @@ class Driver(object, metaclass=abc.ABCMeta):
"""
pass
@abc.abstractmethod
def getProvider(self, provider_config, use_taskmanager):
"""Return a Provider instance
:arg dict provider_config: A ProviderConfig instance
:arg bool use_taskmanager: Whether this provider should use a
task manager (i.e., perform synchronous or asynchronous
operations).
"""
pass
class Provider(object, metaclass=abc.ABCMeta):
"""The Provider interface

View File

@ -16,6 +16,7 @@ import os_client_config
from nodepool.driver import Driver
from nodepool.driver.fake.config import FakeProviderConfig
from nodepool.driver.fake.provider import FakeProvider
class FakeDriver(Driver):
@ -28,3 +29,6 @@ class FakeDriver(Driver):
def getProviderConfig(self, provider):
return FakeProviderConfig(self, provider)
def getProvider(self, provider_config, use_taskmanager):
return FakeProvider(provider_config, use_taskmanager)

View File

@ -16,6 +16,7 @@ import os_client_config
from nodepool.driver import Driver
from nodepool.driver.openstack.config import OpenStackProviderConfig
from nodepool.driver.openstack.provider import OpenStackProvider
class OpenStackDriver(Driver):
@ -28,3 +29,6 @@ class OpenStackDriver(Driver):
def getProviderConfig(self, provider):
return OpenStackProviderConfig(self, provider)
def getProvider(self, provider_config, use_taskmanager):
return OpenStackProvider(provider_config, use_taskmanager)

View File

@ -13,9 +13,13 @@
# limitations under the License.
from nodepool.driver import Driver
from nodepool.driver.static.config import StaticProviderConfig
from nodepool.driver.static import config
from nodepool.driver.static import provider
class StaticDriver(Driver):
def getProviderConfig(self, provider):
return StaticProviderConfig(provider)
return config.StaticProviderConfig(provider)
def getProvider(self, provider_config, use_taskmanager):
return provider.StaticNodeProvider(provider_config, use_taskmanager)

View File

@ -13,9 +13,13 @@
# limitations under the License.
from nodepool.driver import Driver
from nodepool.driver.test.config import TestConfig
from nodepool.driver.test import config
from nodepool.driver.test import provider
class TestDriver(Driver):
def getProviderConfig(self, provider):
return TestConfig(provider)
return config.TestConfig(provider)
def getProvider(self, provider_config, use_taskmanager):
return provider.TestProvider(provider_config)

View File

@ -19,7 +19,7 @@ from nodepool.driver.test import handler
class TestProvider(Provider):
def __init__(self, provider, *args):
def __init__(self, provider):
self.provider = provider
def start(self):

View File

@ -23,7 +23,7 @@ from nodepool.driver import Drivers
def get_provider(provider, use_taskmanager):
driver = Drivers.get(provider.driver.name)
return driver['provider'](provider, use_taskmanager)
return driver.getProvider(provider, use_taskmanager)
class ProviderManager(object):

View File

@ -18,7 +18,6 @@ import uuid
import fixtures
from nodepool import builder, exceptions, tests
from nodepool.driver import Drivers
from nodepool.driver.fake import provider as fakeprovider
from nodepool import zk
@ -121,7 +120,7 @@ class TestNodePoolBuilder(tests.DBTestCase):
return fake_client
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'], '_getClient',
fakeprovider.FakeProvider, '_getClient',
get_fake_client))
configfile = self.setup_config('node.yaml')

View File

@ -21,7 +21,7 @@ import mock
from nodepool import tests
from nodepool import zk
from nodepool.driver import Drivers
from nodepool.driver.fake import provider as fakeprovider
import nodepool.launcher
from kazoo import exceptions as kze
@ -130,7 +130,7 @@ class TestLauncher(tests.DBTestCase):
def fake_get_quota():
return (max_cores, max_instances, max_ram)
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
fakeprovider.FakeProvider.fake_cloud, '_get_quota',
fake_get_quota
))
@ -265,7 +265,7 @@ class TestLauncher(tests.DBTestCase):
def fake_get_quota():
return (max_cores, max_instances, max_ram)
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'].fake_cloud, '_get_quota',
fakeprovider.FakeProvider.fake_cloud, '_get_quota',
fake_get_quota
))
@ -653,7 +653,7 @@ class TestLauncher(tests.DBTestCase):
raise RuntimeError('Fake Error')
self.useFixture(fixtures.MockPatchObject(
Drivers.get('fake')['provider'], 'deleteServer', fail_delete))
fakeprovider.FakeProvider, 'deleteServer', fail_delete))
configfile = self.setup_config('node.yaml')
pool = self.useNodepool(configfile, watermark_sleep=1)