Merge "Extract out common config parsing for ConfigPool"
This commit is contained in:
commit
1fe5fb60c5
|
@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
|
||||||
class ConfigPool(ConfigValue):
|
class ConfigPool(ConfigValue, metaclass=abc.ABCMeta):
|
||||||
|
'''
|
||||||
|
Base class for a single pool as defined in the configuration file.
|
||||||
|
'''
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.labels = {}
|
self.labels = {}
|
||||||
self.max_servers = math.inf
|
self.max_servers = math.inf
|
||||||
|
@ -837,6 +841,40 @@ class ConfigPool(ConfigValue):
|
||||||
self.node_attributes == other.node_attributes)
|
self.node_attributes == other.node_attributes)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def getCommonSchemaDict(self):
|
||||||
|
'''
|
||||||
|
Return the schema dict for common pool attributes.
|
||||||
|
|
||||||
|
When a driver validates its own configuration schema, it should call
|
||||||
|
this class method to get and include the common pool attributes in
|
||||||
|
the schema.
|
||||||
|
|
||||||
|
The `labels` attribute, though common, can vary its type across
|
||||||
|
drivers so it is not returned in the schema.
|
||||||
|
'''
|
||||||
|
return {
|
||||||
|
'max-servers': int,
|
||||||
|
'node-attributes': dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load(self, pool_config):
|
||||||
|
'''
|
||||||
|
Load pool config options from the parsed configuration file.
|
||||||
|
|
||||||
|
Subclasses are expected to call the parent method so that common
|
||||||
|
configuration values are loaded properly.
|
||||||
|
|
||||||
|
Although `labels` is a common attribute, each driver may
|
||||||
|
define it differently, so we cannot parse that attribute here.
|
||||||
|
|
||||||
|
:param dict pool_config: A single pool config section from which we
|
||||||
|
will load the values.
|
||||||
|
'''
|
||||||
|
self.max_servers = pool_config.get('max-servers', math.inf)
|
||||||
|
self.node_attributes = pool_config.get('node-attributes')
|
||||||
|
|
||||||
|
|
||||||
class DriverConfig(ConfigValue):
|
class DriverConfig(ConfigValue):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<KubernetesPool %s>" % self.name
|
return "<KubernetesPool %s>" % self.name
|
||||||
|
|
||||||
|
def load(self, pool_config, full_config):
|
||||||
|
super().load(pool_config)
|
||||||
|
self.name = pool_config['name']
|
||||||
|
self.labels = {}
|
||||||
|
for label in pool_config.get('labels', []):
|
||||||
|
pl = KubernetesLabel()
|
||||||
|
pl.name = label['name']
|
||||||
|
pl.type = label['type']
|
||||||
|
pl.image = label.get('image')
|
||||||
|
pl.image_pull = label.get('image-pull', 'IfNotPresent')
|
||||||
|
pl.pool = self
|
||||||
|
self.labels[pl.name] = pl
|
||||||
|
full_config.labels[label['name']].pools.append(self)
|
||||||
|
|
||||||
|
|
||||||
class KubernetesProviderConfig(ProviderConfig):
|
class KubernetesProviderConfig(ProviderConfig):
|
||||||
def __init__(self, driver, provider):
|
def __init__(self, driver, provider):
|
||||||
|
@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig):
|
||||||
self.context = self.provider['context']
|
self.context = self.provider['context']
|
||||||
for pool in self.provider.get('pools', []):
|
for pool in self.provider.get('pools', []):
|
||||||
pp = KubernetesPool()
|
pp = KubernetesPool()
|
||||||
pp.name = pool['name']
|
pp.load(pool, config)
|
||||||
pp.provider = self
|
pp.provider = self
|
||||||
self.pools[pp.name] = pp
|
self.pools[pp.name] = pp
|
||||||
pp.labels = {}
|
|
||||||
for label in pool.get('labels', []):
|
|
||||||
pl = KubernetesLabel()
|
|
||||||
pl.name = label['name']
|
|
||||||
pl.type = label['type']
|
|
||||||
pl.image = label.get('image')
|
|
||||||
pl.image_pull = label.get('image-pull', 'IfNotPresent')
|
|
||||||
pl.pool = pp
|
|
||||||
pp.labels[pl.name] = pl
|
|
||||||
config.labels[label['name']].pools.append(pp)
|
|
||||||
|
|
||||||
def getSchema(self):
|
def getSchema(self):
|
||||||
k8s_label = {
|
k8s_label = {
|
||||||
|
@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig):
|
||||||
'image-pull': str,
|
'image-pull': str,
|
||||||
}
|
}
|
||||||
|
|
||||||
pool = {
|
pool = ConfigPool.getCommonSchemaDict()
|
||||||
|
pool.update({
|
||||||
v.Required('name'): str,
|
v.Required('name'): str,
|
||||||
v.Required('labels'): [k8s_label],
|
v.Required('labels'): [k8s_label],
|
||||||
}
|
})
|
||||||
|
|
||||||
provider = {
|
provider = {
|
||||||
v.Required('pools'): [pool],
|
v.Required('pools'): [pool],
|
||||||
|
|
|
@ -149,6 +149,64 @@ class ProviderPool(ConfigPool):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<ProviderPool %s>" % self.name
|
return "<ProviderPool %s>" % self.name
|
||||||
|
|
||||||
|
def load(self, pool_config, full_config, provider):
|
||||||
|
'''
|
||||||
|
Load pool configuration options.
|
||||||
|
|
||||||
|
:param dict pool_config: A single pool config section from which we
|
||||||
|
will load the values.
|
||||||
|
:param dict full_config: The full nodepool config.
|
||||||
|
:param OpenStackProviderConfig: The calling provider object.
|
||||||
|
'''
|
||||||
|
super().load(pool_config)
|
||||||
|
|
||||||
|
self.provider = provider
|
||||||
|
self.name = pool_config['name']
|
||||||
|
self.max_cores = pool_config.get('max-cores', math.inf)
|
||||||
|
self.max_ram = pool_config.get('max-ram', math.inf)
|
||||||
|
self.ignore_provider_quota = pool_config.get('ignore-provider-quota',
|
||||||
|
False)
|
||||||
|
self.azs = pool_config.get('availability-zones')
|
||||||
|
self.networks = pool_config.get('networks', [])
|
||||||
|
self.security_groups = pool_config.get('security-groups', [])
|
||||||
|
self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True))
|
||||||
|
self.host_key_checking = bool(pool_config.get('host-key-checking',
|
||||||
|
True))
|
||||||
|
|
||||||
|
for label in pool_config.get('labels', []):
|
||||||
|
pl = ProviderLabel()
|
||||||
|
pl.name = label['name']
|
||||||
|
pl.pool = self
|
||||||
|
self.labels[pl.name] = pl
|
||||||
|
diskimage = label.get('diskimage', None)
|
||||||
|
if diskimage:
|
||||||
|
pl.diskimage = full_config.diskimages[diskimage]
|
||||||
|
else:
|
||||||
|
pl.diskimage = None
|
||||||
|
cloud_image_name = label.get('cloud-image', None)
|
||||||
|
if cloud_image_name:
|
||||||
|
cloud_image = provider.cloud_images.get(cloud_image_name, None)
|
||||||
|
if not cloud_image:
|
||||||
|
raise ValueError(
|
||||||
|
"cloud-image %s does not exist in provider %s"
|
||||||
|
" but is referenced in label %s" %
|
||||||
|
(cloud_image_name, self.name, pl.name))
|
||||||
|
else:
|
||||||
|
cloud_image = None
|
||||||
|
pl.cloud_image = cloud_image
|
||||||
|
pl.min_ram = label.get('min-ram', 0)
|
||||||
|
pl.flavor_name = label.get('flavor-name', None)
|
||||||
|
pl.key_name = label.get('key-name')
|
||||||
|
pl.console_log = label.get('console-log', False)
|
||||||
|
pl.boot_from_volume = bool(label.get('boot-from-volume',
|
||||||
|
False))
|
||||||
|
pl.volume_size = label.get('volume-size', 50)
|
||||||
|
pl.instance_properties = label.get('instance-properties',
|
||||||
|
None)
|
||||||
|
|
||||||
|
top_label = full_config.labels[pl.name]
|
||||||
|
top_label.pools.append(self)
|
||||||
|
|
||||||
|
|
||||||
class OpenStackProviderConfig(ProviderConfig):
|
class OpenStackProviderConfig(ProviderConfig):
|
||||||
def __init__(self, driver, provider):
|
def __init__(self, driver, provider):
|
||||||
|
@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig):
|
||||||
|
|
||||||
for pool in self.provider.get('pools', []):
|
for pool in self.provider.get('pools', []):
|
||||||
pp = ProviderPool()
|
pp = ProviderPool()
|
||||||
pp.name = pool['name']
|
pp.load(pool, config, self)
|
||||||
pp.provider = self
|
|
||||||
self.pools[pp.name] = pp
|
self.pools[pp.name] = pp
|
||||||
pp.max_cores = pool.get('max-cores', math.inf)
|
|
||||||
pp.max_servers = pool.get('max-servers', math.inf)
|
|
||||||
pp.max_ram = pool.get('max-ram', math.inf)
|
|
||||||
pp.ignore_provider_quota = pool.get('ignore-provider-quota', False)
|
|
||||||
pp.azs = pool.get('availability-zones')
|
|
||||||
pp.networks = pool.get('networks', [])
|
|
||||||
pp.security_groups = pool.get('security-groups', [])
|
|
||||||
pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True))
|
|
||||||
pp.host_key_checking = bool(pool.get('host-key-checking', True))
|
|
||||||
pp.node_attributes = pool.get('node-attributes')
|
|
||||||
|
|
||||||
for label in pool.get('labels', []):
|
|
||||||
pl = ProviderLabel()
|
|
||||||
pl.name = label['name']
|
|
||||||
pl.pool = pp
|
|
||||||
pp.labels[pl.name] = pl
|
|
||||||
diskimage = label.get('diskimage', None)
|
|
||||||
if diskimage:
|
|
||||||
pl.diskimage = config.diskimages[diskimage]
|
|
||||||
else:
|
|
||||||
pl.diskimage = None
|
|
||||||
cloud_image_name = label.get('cloud-image', None)
|
|
||||||
if cloud_image_name:
|
|
||||||
cloud_image = self.cloud_images.get(cloud_image_name, None)
|
|
||||||
if not cloud_image:
|
|
||||||
raise ValueError(
|
|
||||||
"cloud-image %s does not exist in provider %s"
|
|
||||||
" but is referenced in label %s" %
|
|
||||||
(cloud_image_name, self.name, pl.name))
|
|
||||||
else:
|
|
||||||
cloud_image = None
|
|
||||||
pl.cloud_image = cloud_image
|
|
||||||
pl.min_ram = label.get('min-ram', 0)
|
|
||||||
pl.flavor_name = label.get('flavor-name', None)
|
|
||||||
pl.key_name = label.get('key-name')
|
|
||||||
pl.console_log = label.get('console-log', False)
|
|
||||||
pl.boot_from_volume = bool(label.get('boot-from-volume',
|
|
||||||
False))
|
|
||||||
pl.volume_size = label.get('volume-size', 50)
|
|
||||||
pl.instance_properties = label.get('instance-properties',
|
|
||||||
None)
|
|
||||||
|
|
||||||
top_label = config.labels[pl.name]
|
|
||||||
top_label.pools.append(pp)
|
|
||||||
|
|
||||||
def getSchema(self):
|
def getSchema(self):
|
||||||
provider_diskimage = {
|
provider_diskimage = {
|
||||||
|
@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig):
|
||||||
v.Any(label_min_ram, label_flavor_name),
|
v.Any(label_min_ram, label_flavor_name),
|
||||||
v.Any(label_diskimage, label_cloud_image))
|
v.Any(label_diskimage, label_cloud_image))
|
||||||
|
|
||||||
pool = {
|
pool = ConfigPool.getCommonSchemaDict()
|
||||||
|
pool.update({
|
||||||
'name': str,
|
'name': str,
|
||||||
'networks': [str],
|
'networks': [str],
|
||||||
'auto-floating-ip': bool,
|
'auto-floating-ip': bool,
|
||||||
'host-key-checking': bool,
|
'host-key-checking': bool,
|
||||||
'ignore-provider-quota': bool,
|
'ignore-provider-quota': bool,
|
||||||
'max-cores': int,
|
'max-cores': int,
|
||||||
'max-servers': int,
|
|
||||||
'max-ram': int,
|
'max-ram': int,
|
||||||
'labels': [pool_label],
|
'labels': [pool_label],
|
||||||
'node-attributes': dict,
|
|
||||||
'availability-zones': [str],
|
'availability-zones': [str],
|
||||||
'security-groups': [str]
|
'security-groups': [str]
|
||||||
}
|
})
|
||||||
|
|
||||||
return v.Schema({
|
return v.Schema({
|
||||||
'region-name': str,
|
'region-name': str,
|
||||||
|
|
|
@ -41,6 +41,33 @@ class StaticPool(ConfigPool):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<StaticPool %s>" % self.name
|
return "<StaticPool %s>" % self.name
|
||||||
|
|
||||||
|
def load(self, pool_config, full_config):
|
||||||
|
super().load(pool_config)
|
||||||
|
self.name = pool_config['name']
|
||||||
|
# WARNING: This intentionally changes the type!
|
||||||
|
self.labels = set()
|
||||||
|
for node in pool_config.get('nodes', []):
|
||||||
|
self.nodes.append({
|
||||||
|
'name': node['name'],
|
||||||
|
'labels': as_list(node['labels']),
|
||||||
|
'host-key': as_list(node.get('host-key', [])),
|
||||||
|
'timeout': int(node.get('timeout', 5)),
|
||||||
|
# Read ssh-port values for backward compat, but prefer port
|
||||||
|
'connection-port': int(
|
||||||
|
node.get('connection-port', node.get('ssh-port', 22))),
|
||||||
|
'connection-type': node.get('connection-type', 'ssh'),
|
||||||
|
'username': node.get('username', 'zuul'),
|
||||||
|
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
|
||||||
|
})
|
||||||
|
if isinstance(node['labels'], str):
|
||||||
|
for label in node['labels'].split():
|
||||||
|
self.labels.add(label)
|
||||||
|
full_config.labels[label].pools.append(self)
|
||||||
|
elif isinstance(node['labels'], list):
|
||||||
|
for label in node['labels']:
|
||||||
|
self.labels.add(label)
|
||||||
|
full_config.labels[label].pools.append(self)
|
||||||
|
|
||||||
|
|
||||||
class StaticProviderConfig(ProviderConfig):
|
class StaticProviderConfig(ProviderConfig):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig):
|
||||||
def load(self, config):
|
def load(self, config):
|
||||||
for pool in self.provider.get('pools', []):
|
for pool in self.provider.get('pools', []):
|
||||||
pp = StaticPool()
|
pp = StaticPool()
|
||||||
pp.name = pool['name']
|
pp.load(pool, config)
|
||||||
pp.provider = self
|
pp.provider = self
|
||||||
self.pools[pp.name] = pp
|
self.pools[pp.name] = pp
|
||||||
# WARNING: This intentionally changes the type!
|
|
||||||
pp.labels = set()
|
|
||||||
for node in pool.get('nodes', []):
|
|
||||||
pp.nodes.append({
|
|
||||||
'name': node['name'],
|
|
||||||
'labels': as_list(node['labels']),
|
|
||||||
'host-key': as_list(node.get('host-key', [])),
|
|
||||||
'timeout': int(node.get('timeout', 5)),
|
|
||||||
# Read ssh-port values for backward compat, but prefer port
|
|
||||||
'connection-port': int(
|
|
||||||
node.get('connection-port', node.get('ssh-port', 22))),
|
|
||||||
'connection-type': node.get('connection-type', 'ssh'),
|
|
||||||
'username': node.get('username', 'zuul'),
|
|
||||||
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
|
|
||||||
})
|
|
||||||
if isinstance(node['labels'], str):
|
|
||||||
for label in node['labels'].split():
|
|
||||||
pp.labels.add(label)
|
|
||||||
config.labels[label].pools.append(pp)
|
|
||||||
elif isinstance(node['labels'], list):
|
|
||||||
for label in node['labels']:
|
|
||||||
pp.labels.add(label)
|
|
||||||
config.labels[label].pools.append(pp)
|
|
||||||
|
|
||||||
def getSchema(self):
|
def getSchema(self):
|
||||||
pool_node = {
|
pool_node = {
|
||||||
|
@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig):
|
||||||
'connection-type': str,
|
'connection-type': str,
|
||||||
'max-parallel-jobs': int,
|
'max-parallel-jobs': int,
|
||||||
}
|
}
|
||||||
pool = {
|
pool = ConfigPool.getCommonSchemaDict()
|
||||||
|
pool.update({
|
||||||
'name': str,
|
'name': str,
|
||||||
'nodes': [pool_node],
|
'nodes': [pool_node],
|
||||||
}
|
})
|
||||||
return v.Schema({'pools': [pool]})
|
return v.Schema({'pools': [pool]})
|
||||||
|
|
||||||
def getSupportedLabels(self, pool_name=None):
|
def getSupportedLabels(self, pool_name=None):
|
||||||
|
|
|
@ -12,7 +12,6 @@
|
||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import math
|
|
||||||
import voluptuous as v
|
import voluptuous as v
|
||||||
|
|
||||||
from nodepool.driver import ConfigPool
|
from nodepool.driver import ConfigPool
|
||||||
|
@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig
|
||||||
|
|
||||||
|
|
||||||
class TestPool(ConfigPool):
|
class TestPool(ConfigPool):
|
||||||
pass
|
def load(self, pool_config):
|
||||||
|
super().load(pool_config)
|
||||||
|
self.name = pool_config['name']
|
||||||
|
self.labels = pool_config['labels']
|
||||||
|
|
||||||
|
|
||||||
class TestConfig(ProviderConfig):
|
class TestConfig(ProviderConfig):
|
||||||
|
@ -43,18 +45,19 @@ class TestConfig(ProviderConfig):
|
||||||
self.labels = set()
|
self.labels = set()
|
||||||
for pool in self.provider.get('pools', []):
|
for pool in self.provider.get('pools', []):
|
||||||
testpool = TestPool()
|
testpool = TestPool()
|
||||||
testpool.name = pool['name']
|
testpool.load(pool)
|
||||||
testpool.provider = self
|
testpool.provider = self
|
||||||
testpool.max_servers = pool.get('max-servers', math.inf)
|
|
||||||
testpool.labels = pool['labels']
|
|
||||||
for label in pool['labels']:
|
for label in pool['labels']:
|
||||||
self.labels.add(label)
|
self.labels.add(label)
|
||||||
newconfig.labels[label].pools.append(testpool)
|
newconfig.labels[label].pools.append(testpool)
|
||||||
self.pools[pool['name']] = testpool
|
self.pools[pool['name']] = testpool
|
||||||
|
|
||||||
def getSchema(self):
|
def getSchema(self):
|
||||||
pool = {'name': str,
|
pool = ConfigPool.getCommonSchemaDict()
|
||||||
'labels': [str]}
|
pool.update({
|
||||||
|
'name': str,
|
||||||
|
'labels': [str]
|
||||||
|
})
|
||||||
return v.Schema({'pools': [pool]})
|
return v.Schema({'pools': [pool]})
|
||||||
|
|
||||||
def getSupportedLabels(self, pool_name=None):
|
def getSupportedLabels(self, pool_name=None):
|
||||||
|
|
|
@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool
|
||||||
from nodepool.driver.static.config import StaticProviderConfig
|
from nodepool.driver.static.config import StaticProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TempConfigPool(ConfigPool):
|
||||||
|
def load(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestConfigComparisons(tests.BaseTestCase):
|
class TestConfigComparisons(tests.BaseTestCase):
|
||||||
|
|
||||||
def test_ConfigPool(self):
|
def test_ConfigPool(self):
|
||||||
a = ConfigPool()
|
|
||||||
b = ConfigPool()
|
a = TempConfigPool()
|
||||||
|
b = TempConfigPool()
|
||||||
self.assertEqual(a, b)
|
self.assertEqual(a, b)
|
||||||
a.max_servers = 5
|
a.max_servers = 5
|
||||||
self.assertNotEqual(a, b)
|
self.assertNotEqual(a, b)
|
||||||
|
@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase):
|
||||||
a.max_servers = 5
|
a.max_servers = 5
|
||||||
self.assertNotEqual(a, b)
|
self.assertNotEqual(a, b)
|
||||||
|
|
||||||
c = ConfigPool()
|
c = TempConfigPool()
|
||||||
d = ProviderPool()
|
d = ProviderPool()
|
||||||
self.assertNotEqual(c, d)
|
self.assertNotEqual(d, c)
|
||||||
|
|
||||||
def test_OpenStackProviderConfig(self):
|
def test_OpenStackProviderConfig(self):
|
||||||
provider = {'name': 'foo'}
|
provider = {'name': 'foo'}
|
||||||
|
@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase):
|
||||||
# intentionally change an attribute of the base class
|
# intentionally change an attribute of the base class
|
||||||
a.max_servers = 5
|
a.max_servers = 5
|
||||||
self.assertNotEqual(a, b)
|
self.assertNotEqual(a, b)
|
||||||
c = ConfigPool()
|
c = TempConfigPool()
|
||||||
self.assertNotEqual(b, c)
|
self.assertNotEqual(b, c)
|
||||||
|
|
||||||
def test_StaticProviderConfig(self):
|
def test_StaticProviderConfig(self):
|
||||||
|
|
Loading…
Reference in New Issue