Merge "Extract out common config parsing for ConfigPool"

This commit is contained in:
Zuul 2019-01-09 18:14:30 +00:00 committed by Gerrit Code Review
commit 1fe5fb60c5
6 changed files with 171 additions and 102 deletions

View File

@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
return not self.__eq__(other) return not self.__eq__(other)
class ConfigPool(ConfigValue): class ConfigPool(ConfigValue, metaclass=abc.ABCMeta):
'''
Base class for a single pool as defined in the configuration file.
'''
def __init__(self): def __init__(self):
self.labels = {} self.labels = {}
self.max_servers = math.inf self.max_servers = math.inf
@ -837,6 +841,40 @@ class ConfigPool(ConfigValue):
self.node_attributes == other.node_attributes) self.node_attributes == other.node_attributes)
return False return False
@classmethod
def getCommonSchemaDict(self):
'''
Return the schema dict for common pool attributes.
When a driver validates its own configuration schema, it should call
this class method to get and include the common pool attributes in
the schema.
The `labels` attribute, though common, can vary its type across
drivers so it is not returned in the schema.
'''
return {
'max-servers': int,
'node-attributes': dict,
}
@abc.abstractmethod
def load(self, pool_config):
'''
Load pool config options from the parsed configuration file.
Subclasses are expected to call the parent method so that common
configuration values are loaded properly.
Although `labels` is a common attribute, each driver may
define it differently, so we cannot parse that attribute here.
:param dict pool_config: A single pool config section from which we
will load the values.
'''
self.max_servers = pool_config.get('max-servers', math.inf)
self.node_attributes = pool_config.get('node-attributes')
class DriverConfig(ConfigValue): class DriverConfig(ConfigValue):
def __init__(self): def __init__(self):

View File

@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool):
def __repr__(self): def __repr__(self):
return "<KubernetesPool %s>" % self.name return "<KubernetesPool %s>" % self.name
def load(self, pool_config, full_config):
super().load(pool_config)
self.name = pool_config['name']
self.labels = {}
for label in pool_config.get('labels', []):
pl = KubernetesLabel()
pl.name = label['name']
pl.type = label['type']
pl.image = label.get('image')
pl.image_pull = label.get('image-pull', 'IfNotPresent')
pl.pool = self
self.labels[pl.name] = pl
full_config.labels[label['name']].pools.append(self)
class KubernetesProviderConfig(ProviderConfig): class KubernetesProviderConfig(ProviderConfig):
def __init__(self, driver, provider): def __init__(self, driver, provider):
@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig):
self.context = self.provider['context'] self.context = self.provider['context']
for pool in self.provider.get('pools', []): for pool in self.provider.get('pools', []):
pp = KubernetesPool() pp = KubernetesPool()
pp.name = pool['name'] pp.load(pool, config)
pp.provider = self pp.provider = self
self.pools[pp.name] = pp self.pools[pp.name] = pp
pp.labels = {}
for label in pool.get('labels', []):
pl = KubernetesLabel()
pl.name = label['name']
pl.type = label['type']
pl.image = label.get('image')
pl.image_pull = label.get('image-pull', 'IfNotPresent')
pl.pool = pp
pp.labels[pl.name] = pl
config.labels[label['name']].pools.append(pp)
def getSchema(self): def getSchema(self):
k8s_label = { k8s_label = {
@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig):
'image-pull': str, 'image-pull': str,
} }
pool = { pool = ConfigPool.getCommonSchemaDict()
pool.update({
v.Required('name'): str, v.Required('name'): str,
v.Required('labels'): [k8s_label], v.Required('labels'): [k8s_label],
} })
provider = { provider = {
v.Required('pools'): [pool], v.Required('pools'): [pool],

View File

@ -149,6 +149,64 @@ class ProviderPool(ConfigPool):
def __repr__(self): def __repr__(self):
return "<ProviderPool %s>" % self.name return "<ProviderPool %s>" % self.name
def load(self, pool_config, full_config, provider):
'''
Load pool configuration options.
:param dict pool_config: A single pool config section from which we
will load the values.
:param dict full_config: The full nodepool config.
:param OpenStackProviderConfig: The calling provider object.
'''
super().load(pool_config)
self.provider = provider
self.name = pool_config['name']
self.max_cores = pool_config.get('max-cores', math.inf)
self.max_ram = pool_config.get('max-ram', math.inf)
self.ignore_provider_quota = pool_config.get('ignore-provider-quota',
False)
self.azs = pool_config.get('availability-zones')
self.networks = pool_config.get('networks', [])
self.security_groups = pool_config.get('security-groups', [])
self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True))
self.host_key_checking = bool(pool_config.get('host-key-checking',
True))
for label in pool_config.get('labels', []):
pl = ProviderLabel()
pl.name = label['name']
pl.pool = self
self.labels[pl.name] = pl
diskimage = label.get('diskimage', None)
if diskimage:
pl.diskimage = full_config.diskimages[diskimage]
else:
pl.diskimage = None
cloud_image_name = label.get('cloud-image', None)
if cloud_image_name:
cloud_image = provider.cloud_images.get(cloud_image_name, None)
if not cloud_image:
raise ValueError(
"cloud-image %s does not exist in provider %s"
" but is referenced in label %s" %
(cloud_image_name, self.name, pl.name))
else:
cloud_image = None
pl.cloud_image = cloud_image
pl.min_ram = label.get('min-ram', 0)
pl.flavor_name = label.get('flavor-name', None)
pl.key_name = label.get('key-name')
pl.console_log = label.get('console-log', False)
pl.boot_from_volume = bool(label.get('boot-from-volume',
False))
pl.volume_size = label.get('volume-size', 50)
pl.instance_properties = label.get('instance-properties',
None)
top_label = full_config.labels[pl.name]
top_label.pools.append(self)
class OpenStackProviderConfig(ProviderConfig): class OpenStackProviderConfig(ProviderConfig):
def __init__(self, driver, provider): def __init__(self, driver, provider):
@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig):
for pool in self.provider.get('pools', []): for pool in self.provider.get('pools', []):
pp = ProviderPool() pp = ProviderPool()
pp.name = pool['name'] pp.load(pool, config, self)
pp.provider = self
self.pools[pp.name] = pp self.pools[pp.name] = pp
pp.max_cores = pool.get('max-cores', math.inf)
pp.max_servers = pool.get('max-servers', math.inf)
pp.max_ram = pool.get('max-ram', math.inf)
pp.ignore_provider_quota = pool.get('ignore-provider-quota', False)
pp.azs = pool.get('availability-zones')
pp.networks = pool.get('networks', [])
pp.security_groups = pool.get('security-groups', [])
pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True))
pp.host_key_checking = bool(pool.get('host-key-checking', True))
pp.node_attributes = pool.get('node-attributes')
for label in pool.get('labels', []):
pl = ProviderLabel()
pl.name = label['name']
pl.pool = pp
pp.labels[pl.name] = pl
diskimage = label.get('diskimage', None)
if diskimage:
pl.diskimage = config.diskimages[diskimage]
else:
pl.diskimage = None
cloud_image_name = label.get('cloud-image', None)
if cloud_image_name:
cloud_image = self.cloud_images.get(cloud_image_name, None)
if not cloud_image:
raise ValueError(
"cloud-image %s does not exist in provider %s"
" but is referenced in label %s" %
(cloud_image_name, self.name, pl.name))
else:
cloud_image = None
pl.cloud_image = cloud_image
pl.min_ram = label.get('min-ram', 0)
pl.flavor_name = label.get('flavor-name', None)
pl.key_name = label.get('key-name')
pl.console_log = label.get('console-log', False)
pl.boot_from_volume = bool(label.get('boot-from-volume',
False))
pl.volume_size = label.get('volume-size', 50)
pl.instance_properties = label.get('instance-properties',
None)
top_label = config.labels[pl.name]
top_label.pools.append(pp)
def getSchema(self): def getSchema(self):
provider_diskimage = { provider_diskimage = {
@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig):
v.Any(label_min_ram, label_flavor_name), v.Any(label_min_ram, label_flavor_name),
v.Any(label_diskimage, label_cloud_image)) v.Any(label_diskimage, label_cloud_image))
pool = { pool = ConfigPool.getCommonSchemaDict()
pool.update({
'name': str, 'name': str,
'networks': [str], 'networks': [str],
'auto-floating-ip': bool, 'auto-floating-ip': bool,
'host-key-checking': bool, 'host-key-checking': bool,
'ignore-provider-quota': bool, 'ignore-provider-quota': bool,
'max-cores': int, 'max-cores': int,
'max-servers': int,
'max-ram': int, 'max-ram': int,
'labels': [pool_label], 'labels': [pool_label],
'node-attributes': dict,
'availability-zones': [str], 'availability-zones': [str],
'security-groups': [str] 'security-groups': [str]
} })
return v.Schema({ return v.Schema({
'region-name': str, 'region-name': str,

View File

@ -41,6 +41,33 @@ class StaticPool(ConfigPool):
def __repr__(self): def __repr__(self):
return "<StaticPool %s>" % self.name return "<StaticPool %s>" % self.name
def load(self, pool_config, full_config):
super().load(pool_config)
self.name = pool_config['name']
# WARNING: This intentionally changes the type!
self.labels = set()
for node in pool_config.get('nodes', []):
self.nodes.append({
'name': node['name'],
'labels': as_list(node['labels']),
'host-key': as_list(node.get('host-key', [])),
'timeout': int(node.get('timeout', 5)),
# Read ssh-port values for backward compat, but prefer port
'connection-port': int(
node.get('connection-port', node.get('ssh-port', 22))),
'connection-type': node.get('connection-type', 'ssh'),
'username': node.get('username', 'zuul'),
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
})
if isinstance(node['labels'], str):
for label in node['labels'].split():
self.labels.add(label)
full_config.labels[label].pools.append(self)
elif isinstance(node['labels'], list):
for label in node['labels']:
self.labels.add(label)
full_config.labels[label].pools.append(self)
class StaticProviderConfig(ProviderConfig): class StaticProviderConfig(ProviderConfig):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig):
def load(self, config): def load(self, config):
for pool in self.provider.get('pools', []): for pool in self.provider.get('pools', []):
pp = StaticPool() pp = StaticPool()
pp.name = pool['name'] pp.load(pool, config)
pp.provider = self pp.provider = self
self.pools[pp.name] = pp self.pools[pp.name] = pp
# WARNING: This intentionally changes the type!
pp.labels = set()
for node in pool.get('nodes', []):
pp.nodes.append({
'name': node['name'],
'labels': as_list(node['labels']),
'host-key': as_list(node.get('host-key', [])),
'timeout': int(node.get('timeout', 5)),
# Read ssh-port values for backward compat, but prefer port
'connection-port': int(
node.get('connection-port', node.get('ssh-port', 22))),
'connection-type': node.get('connection-type', 'ssh'),
'username': node.get('username', 'zuul'),
'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
})
if isinstance(node['labels'], str):
for label in node['labels'].split():
pp.labels.add(label)
config.labels[label].pools.append(pp)
elif isinstance(node['labels'], list):
for label in node['labels']:
pp.labels.add(label)
config.labels[label].pools.append(pp)
def getSchema(self): def getSchema(self):
pool_node = { pool_node = {
@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig):
'connection-type': str, 'connection-type': str,
'max-parallel-jobs': int, 'max-parallel-jobs': int,
} }
pool = { pool = ConfigPool.getCommonSchemaDict()
pool.update({
'name': str, 'name': str,
'nodes': [pool_node], 'nodes': [pool_node],
} })
return v.Schema({'pools': [pool]}) return v.Schema({'pools': [pool]})
def getSupportedLabels(self, pool_name=None): def getSupportedLabels(self, pool_name=None):

View File

@ -12,7 +12,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import math
import voluptuous as v import voluptuous as v
from nodepool.driver import ConfigPool from nodepool.driver import ConfigPool
@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig
class TestPool(ConfigPool): class TestPool(ConfigPool):
pass def load(self, pool_config):
super().load(pool_config)
self.name = pool_config['name']
self.labels = pool_config['labels']
class TestConfig(ProviderConfig): class TestConfig(ProviderConfig):
@ -43,18 +45,19 @@ class TestConfig(ProviderConfig):
self.labels = set() self.labels = set()
for pool in self.provider.get('pools', []): for pool in self.provider.get('pools', []):
testpool = TestPool() testpool = TestPool()
testpool.name = pool['name'] testpool.load(pool)
testpool.provider = self testpool.provider = self
testpool.max_servers = pool.get('max-servers', math.inf)
testpool.labels = pool['labels']
for label in pool['labels']: for label in pool['labels']:
self.labels.add(label) self.labels.add(label)
newconfig.labels[label].pools.append(testpool) newconfig.labels[label].pools.append(testpool)
self.pools[pool['name']] = testpool self.pools[pool['name']] = testpool
def getSchema(self): def getSchema(self):
pool = {'name': str, pool = ConfigPool.getCommonSchemaDict()
'labels': [str]} pool.update({
'name': str,
'labels': [str]
})
return v.Schema({'pools': [pool]}) return v.Schema({'pools': [pool]})
def getSupportedLabels(self, pool_name=None): def getSupportedLabels(self, pool_name=None):

View File

@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool
from nodepool.driver.static.config import StaticProviderConfig from nodepool.driver.static.config import StaticProviderConfig
class TempConfigPool(ConfigPool):
def load(self):
pass
class TestConfigComparisons(tests.BaseTestCase): class TestConfigComparisons(tests.BaseTestCase):
def test_ConfigPool(self): def test_ConfigPool(self):
a = ConfigPool()
b = ConfigPool() a = TempConfigPool()
b = TempConfigPool()
self.assertEqual(a, b) self.assertEqual(a, b)
a.max_servers = 5 a.max_servers = 5
self.assertNotEqual(a, b) self.assertNotEqual(a, b)
@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase):
a.max_servers = 5 a.max_servers = 5
self.assertNotEqual(a, b) self.assertNotEqual(a, b)
c = ConfigPool() c = TempConfigPool()
d = ProviderPool() d = ProviderPool()
self.assertNotEqual(c, d) self.assertNotEqual(d, c)
def test_OpenStackProviderConfig(self): def test_OpenStackProviderConfig(self):
provider = {'name': 'foo'} provider = {'name': 'foo'}
@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase):
# intentionally change an attribute of the base class # intentionally change an attribute of the base class
a.max_servers = 5 a.max_servers = 5
self.assertNotEqual(a, b) self.assertNotEqual(a, b)
c = ConfigPool() c = TempConfigPool()
self.assertNotEqual(b, c) self.assertNotEqual(b, c)
def test_StaticProviderConfig(self): def test_StaticProviderConfig(self):