Extract common config parsing for ProviderConfig

Adds a ProviderConfig class method that can be called to get
the config schema for the common config options in a Provider.
Drivers are modified to call this method.

Change-Id: Ib67256dddc06d13eb7683226edaa8c8c10a73326
This commit is contained in:
David Shrewsbury 2018-12-13 15:14:20 -05:00 committed by Tobias Henkel
parent a19dffd916
commit d6ef934b70
6 changed files with 24 additions and 9 deletions

View File

@ -14,6 +14,7 @@ import logging
import voluptuous as v import voluptuous as v
import yaml import yaml
from nodepool.driver import ProviderConfig
from nodepool.config import get_provider_config from nodepool.config import get_provider_config
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -26,11 +27,7 @@ class ConfigValidator:
self.config_file = config_file self.config_file = config_file
def validate(self): def validate(self):
provider = { provider = ProviderConfig.getCommonSchemaDict()
'name': v.Required(str),
'driver': str,
'max-concurrency': int,
}
label = { label = {
'name': str, 'name': str,

View File

@ -22,6 +22,7 @@ import importlib
import logging import logging
import math import math
import os import os
import voluptuous as v
from nodepool import zk from nodepool import zk
from nodepool import exceptions from nodepool import exceptions
@ -910,6 +911,14 @@ class ProviderConfig(ConfigValue, metaclass=abc.ABCMeta):
def __repr__(self): def __repr__(self):
return "<Provider %s>" % self.name return "<Provider %s>" % self.name
@classmethod
def getCommonSchemaDict(self):
return {
v.Required('name'): str,
'driver': str,
'max-concurrency': int
}
@property @property
@abc.abstractmethod @abc.abstractmethod
def pools(self): def pools(self):

View File

@ -109,7 +109,10 @@ class KubernetesProviderConfig(ProviderConfig):
v.Required('context'): str, v.Required('context'): str,
'launch-retries': int, 'launch-retries': int,
} }
return v.Schema(provider)
schema = ProviderConfig.getCommonSchemaDict()
schema.update(provider)
return v.Schema(schema)
def getSupportedLabels(self, pool_name=None): def getSupportedLabels(self, pool_name=None):
labels = set() labels = set()

View File

@ -385,7 +385,8 @@ class OpenStackProviderConfig(ProviderConfig):
'security-groups': [str] 'security-groups': [str]
}) })
return v.Schema({ schema = ProviderConfig.getCommonSchemaDict()
schema.update({
'region-name': str, 'region-name': str,
v.Required('cloud'): str, v.Required('cloud'): str,
'boot-timeout': int, 'boot-timeout': int,
@ -400,6 +401,7 @@ class OpenStackProviderConfig(ProviderConfig):
'diskimages': [provider_diskimage], 'diskimages': [provider_diskimage],
'cloud-images': [provider_cloud_images], 'cloud-images': [provider_cloud_images],
}) })
return v.Schema(schema)
def getSupportedLabels(self, pool_name=None): def getSupportedLabels(self, pool_name=None):
labels = set() labels = set()

View File

@ -112,7 +112,9 @@ class StaticProviderConfig(ProviderConfig):
'name': str, 'name': str,
'nodes': [pool_node], 'nodes': [pool_node],
}) })
return v.Schema({'pools': [pool]}) schema = ProviderConfig.getCommonSchemaDict()
schema.update({'pools': [pool]})
return v.Schema(schema)
def getSupportedLabels(self, pool_name=None): def getSupportedLabels(self, pool_name=None):
labels = set() labels = set()

View File

@ -58,7 +58,9 @@ class TestConfig(ProviderConfig):
'name': str, 'name': str,
'labels': [str] 'labels': [str]
}) })
return v.Schema({'pools': [pool]}) schema = ProviderConfig.getCommonSchemaDict()
schema.update({'pools': [pool]})
return v.Schema(schema)
def getSupportedLabels(self, pool_name=None): def getSupportedLabels(self, pool_name=None):
return self.labels return self.labels