Browse Source

Extract out common config parsing for ConfigPool

Our driver code is in a less-than-ideal situation where each driver
is responsible for parsing config options that are common to all
drivers. This change begins to correct that, starting with ConfigPool.
It changes the driver API in the following ways:

1) Forces objects derived from ConfigPool to implement a load() method
   that should call super's method, then handle loading driver specific
   options from the config.

2) Adds a ConfigPool class method that can be called to get the config
   schema for the common config options leaving drivers to have to only
   define the schema for their own config options.

Other base config objects will be modeled after this pattern in
later changes.

Change-Id: I41620590c355cacd2c4fbe6916acfe80f20e3216
tags/3.5.0
David Shrewsbury 5 months ago
parent
commit
a19dffd916

+ 39
- 1
nodepool/driver/__init__.py View File

@@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
824 824
         return not self.__eq__(other)
825 825
 
826 826
 
827
-class ConfigPool(ConfigValue):
827
+class ConfigPool(ConfigValue, metaclass=abc.ABCMeta):
828
+    '''
829
+    Base class for a single pool as defined in the configuration file.
830
+    '''
831
+
828 832
     def __init__(self):
829 833
         self.labels = {}
830 834
         self.max_servers = math.inf
@@ -837,6 +841,40 @@ class ConfigPool(ConfigValue):
837 841
                     self.node_attributes == other.node_attributes)
838 842
         return False
839 843
 
844
+    @classmethod
845
+    def getCommonSchemaDict(self):
846
+        '''
847
+        Return the schema dict for common pool attributes.
848
+
849
+        When a driver validates its own configuration schema, it should call
850
+        this class method to get and include the common pool attributes in
851
+        the schema.
852
+
853
+        The `labels` attribute, though common, can vary its type across
854
+        drivers so it is not returned in the schema.
855
+        '''
856
+        return {
857
+            'max-servers': int,
858
+            'node-attributes': dict,
859
+        }
860
+
861
+    @abc.abstractmethod
862
+    def load(self, pool_config):
863
+        '''
864
+        Load pool config options from the parsed configuration file.
865
+
866
+        Subclasses are expected to call the parent method so that common
867
+        configuration values are loaded properly.
868
+
869
+        Although `labels` is a common attribute, each driver may
870
+        define it differently, so we cannot parse that attribute here.
871
+
872
+        :param dict pool_config: A single pool config section from which we
873
+            will load the values.
874
+        '''
875
+        self.max_servers = pool_config.get('max-servers', math.inf)
876
+        self.node_attributes = pool_config.get('node-attributes')
877
+
840 878
 
841 879
 class DriverConfig(ConfigValue):
842 880
     def __init__(self):

+ 18
- 13
nodepool/driver/kubernetes/config.py View File

@@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool):
45 45
     def __repr__(self):
46 46
         return "<KubernetesPool %s>" % self.name
47 47
 
48
+    def load(self, pool_config, full_config):
49
+        super().load(pool_config)
50
+        self.name = pool_config['name']
51
+        self.labels = {}
52
+        for label in pool_config.get('labels', []):
53
+            pl = KubernetesLabel()
54
+            pl.name = label['name']
55
+            pl.type = label['type']
56
+            pl.image = label.get('image')
57
+            pl.image_pull = label.get('image-pull', 'IfNotPresent')
58
+            pl.pool = self
59
+            self.labels[pl.name] = pl
60
+            full_config.labels[label['name']].pools.append(self)
61
+
48 62
 
49 63
 class KubernetesProviderConfig(ProviderConfig):
50 64
     def __init__(self, driver, provider):
@@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig):
72 86
         self.context = self.provider['context']
73 87
         for pool in self.provider.get('pools', []):
74 88
             pp = KubernetesPool()
75
-            pp.name = pool['name']
89
+            pp.load(pool, config)
76 90
             pp.provider = self
77 91
             self.pools[pp.name] = pp
78
-            pp.labels = {}
79
-            for label in pool.get('labels', []):
80
-                pl = KubernetesLabel()
81
-                pl.name = label['name']
82
-                pl.type = label['type']
83
-                pl.image = label.get('image')
84
-                pl.image_pull = label.get('image-pull', 'IfNotPresent')
85
-                pl.pool = pp
86
-                pp.labels[pl.name] = pl
87
-                config.labels[label['name']].pools.append(pp)
88 92
 
89 93
     def getSchema(self):
90 94
         k8s_label = {
@@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig):
94 98
             'image-pull': str,
95 99
         }
96 100
 
97
-        pool = {
101
+        pool = ConfigPool.getCommonSchemaDict()
102
+        pool.update({
98 103
             v.Required('name'): str,
99 104
             v.Required('labels'): [k8s_label],
100
-        }
105
+        })
101 106
 
102 107
         provider = {
103 108
             v.Required('pools'): [pool],

+ 62
- 50
nodepool/driver/openstack/config.py View File

@@ -149,6 +149,64 @@ class ProviderPool(ConfigPool):
149 149
     def __repr__(self):
150 150
         return "<ProviderPool %s>" % self.name
151 151
 
152
+    def load(self, pool_config, full_config, provider):
153
+        '''
154
+        Load pool configuration options.
155
+
156
+        :param dict pool_config: A single pool config section from which we
157
+            will load the values.
158
+        :param dict full_config: The full nodepool config.
159
+        :param OpenStackProviderConfig: The calling provider object.
160
+        '''
161
+        super().load(pool_config)
162
+
163
+        self.provider = provider
164
+        self.name = pool_config['name']
165
+        self.max_cores = pool_config.get('max-cores', math.inf)
166
+        self.max_ram = pool_config.get('max-ram', math.inf)
167
+        self.ignore_provider_quota = pool_config.get('ignore-provider-quota',
168
+                                                     False)
169
+        self.azs = pool_config.get('availability-zones')
170
+        self.networks = pool_config.get('networks', [])
171
+        self.security_groups = pool_config.get('security-groups', [])
172
+        self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True))
173
+        self.host_key_checking = bool(pool_config.get('host-key-checking',
174
+                                                      True))
175
+
176
+        for label in pool_config.get('labels', []):
177
+            pl = ProviderLabel()
178
+            pl.name = label['name']
179
+            pl.pool = self
180
+            self.labels[pl.name] = pl
181
+            diskimage = label.get('diskimage', None)
182
+            if diskimage:
183
+                pl.diskimage = full_config.diskimages[diskimage]
184
+            else:
185
+                pl.diskimage = None
186
+            cloud_image_name = label.get('cloud-image', None)
187
+            if cloud_image_name:
188
+                cloud_image = provider.cloud_images.get(cloud_image_name, None)
189
+                if not cloud_image:
190
+                    raise ValueError(
191
+                        "cloud-image %s does not exist in provider %s"
192
+                        " but is referenced in label %s" %
193
+                        (cloud_image_name, self.name, pl.name))
194
+            else:
195
+                cloud_image = None
196
+            pl.cloud_image = cloud_image
197
+            pl.min_ram = label.get('min-ram', 0)
198
+            pl.flavor_name = label.get('flavor-name', None)
199
+            pl.key_name = label.get('key-name')
200
+            pl.console_log = label.get('console-log', False)
201
+            pl.boot_from_volume = bool(label.get('boot-from-volume',
202
+                                                 False))
203
+            pl.volume_size = label.get('volume-size', 50)
204
+            pl.instance_properties = label.get('instance-properties',
205
+                                               None)
206
+
207
+            top_label = full_config.labels[pl.name]
208
+            top_label.pools.append(self)
209
+
152 210
 
153 211
 class OpenStackProviderConfig(ProviderConfig):
154 212
     def __init__(self, driver, provider):
@@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig):
263 321
 
264 322
         for pool in self.provider.get('pools', []):
265 323
             pp = ProviderPool()
266
-            pp.name = pool['name']
267
-            pp.provider = self
324
+            pp.load(pool, config, self)
268 325
             self.pools[pp.name] = pp
269
-            pp.max_cores = pool.get('max-cores', math.inf)
270
-            pp.max_servers = pool.get('max-servers', math.inf)
271
-            pp.max_ram = pool.get('max-ram', math.inf)
272
-            pp.ignore_provider_quota = pool.get('ignore-provider-quota', False)
273
-            pp.azs = pool.get('availability-zones')
274
-            pp.networks = pool.get('networks', [])
275
-            pp.security_groups = pool.get('security-groups', [])
276
-            pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True))
277
-            pp.host_key_checking = bool(pool.get('host-key-checking', True))
278
-            pp.node_attributes = pool.get('node-attributes')
279
-
280
-            for label in pool.get('labels', []):
281
-                pl = ProviderLabel()
282
-                pl.name = label['name']
283
-                pl.pool = pp
284
-                pp.labels[pl.name] = pl
285
-                diskimage = label.get('diskimage', None)
286
-                if diskimage:
287
-                    pl.diskimage = config.diskimages[diskimage]
288
-                else:
289
-                    pl.diskimage = None
290
-                cloud_image_name = label.get('cloud-image', None)
291
-                if cloud_image_name:
292
-                    cloud_image = self.cloud_images.get(cloud_image_name, None)
293
-                    if not cloud_image:
294
-                        raise ValueError(
295
-                            "cloud-image %s does not exist in provider %s"
296
-                            " but is referenced in label %s" %
297
-                            (cloud_image_name, self.name, pl.name))
298
-                else:
299
-                    cloud_image = None
300
-                pl.cloud_image = cloud_image
301
-                pl.min_ram = label.get('min-ram', 0)
302
-                pl.flavor_name = label.get('flavor-name', None)
303
-                pl.key_name = label.get('key-name')
304
-                pl.console_log = label.get('console-log', False)
305
-                pl.boot_from_volume = bool(label.get('boot-from-volume',
306
-                                                     False))
307
-                pl.volume_size = label.get('volume-size', 50)
308
-                pl.instance_properties = label.get('instance-properties',
309
-                                                   None)
310
-
311
-                top_label = config.labels[pl.name]
312
-                top_label.pools.append(pp)
313 326
 
314 327
     def getSchema(self):
315 328
         provider_diskimage = {
@@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig):
358 371
                            v.Any(label_min_ram, label_flavor_name),
359 372
                            v.Any(label_diskimage, label_cloud_image))
360 373
 
361
-        pool = {
374
+        pool = ConfigPool.getCommonSchemaDict()
375
+        pool.update({
362 376
             'name': str,
363 377
             'networks': [str],
364 378
             'auto-floating-ip': bool,
365 379
             'host-key-checking': bool,
366 380
             'ignore-provider-quota': bool,
367 381
             'max-cores': int,
368
-            'max-servers': int,
369 382
             'max-ram': int,
370 383
             'labels': [pool_label],
371
-            'node-attributes': dict,
372 384
             'availability-zones': [str],
373 385
             'security-groups': [str]
374
-        }
386
+        })
375 387
 
376 388
         return v.Schema({
377 389
             'region-name': str,

+ 31
- 26
nodepool/driver/static/config.py View File

@@ -41,6 +41,33 @@ class StaticPool(ConfigPool):
41 41
     def __repr__(self):
42 42
         return "<StaticPool %s>" % self.name
43 43
 
44
+    def load(self, pool_config, full_config):
45
+        super().load(pool_config)
46
+        self.name = pool_config['name']
47
+        # WARNING: This intentionally changes the type!
48
+        self.labels = set()
49
+        for node in pool_config.get('nodes', []):
50
+            self.nodes.append({
51
+                'name': node['name'],
52
+                'labels': as_list(node['labels']),
53
+                'host-key': as_list(node.get('host-key', [])),
54
+                'timeout': int(node.get('timeout', 5)),
55
+                # Read ssh-port values for backward compat, but prefer port
56
+                'connection-port': int(
57
+                    node.get('connection-port', node.get('ssh-port', 22))),
58
+                'connection-type': node.get('connection-type', 'ssh'),
59
+                'username': node.get('username', 'zuul'),
60
+                'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
61
+            })
62
+            if isinstance(node['labels'], str):
63
+                for label in node['labels'].split():
64
+                    self.labels.add(label)
65
+                    full_config.labels[label].pools.append(self)
66
+            elif isinstance(node['labels'], list):
67
+                for label in node['labels']:
68
+                    self.labels.add(label)
69
+                    full_config.labels[label].pools.append(self)
70
+
44 71
 
45 72
 class StaticProviderConfig(ProviderConfig):
46 73
     def __init__(self, *args, **kwargs):
@@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig):
65 92
     def load(self, config):
66 93
         for pool in self.provider.get('pools', []):
67 94
             pp = StaticPool()
68
-            pp.name = pool['name']
95
+            pp.load(pool, config)
69 96
             pp.provider = self
70 97
             self.pools[pp.name] = pp
71
-            # WARNING: This intentionally changes the type!
72
-            pp.labels = set()
73
-            for node in pool.get('nodes', []):
74
-                pp.nodes.append({
75
-                    'name': node['name'],
76
-                    'labels': as_list(node['labels']),
77
-                    'host-key': as_list(node.get('host-key', [])),
78
-                    'timeout': int(node.get('timeout', 5)),
79
-                    # Read ssh-port values for backward compat, but prefer port
80
-                    'connection-port': int(
81
-                        node.get('connection-port', node.get('ssh-port', 22))),
82
-                    'connection-type': node.get('connection-type', 'ssh'),
83
-                    'username': node.get('username', 'zuul'),
84
-                    'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
85
-                })
86
-                if isinstance(node['labels'], str):
87
-                    for label in node['labels'].split():
88
-                        pp.labels.add(label)
89
-                        config.labels[label].pools.append(pp)
90
-                elif isinstance(node['labels'], list):
91
-                    for label in node['labels']:
92
-                        pp.labels.add(label)
93
-                        config.labels[label].pools.append(pp)
94 98
 
95 99
     def getSchema(self):
96 100
         pool_node = {
@@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig):
103 107
             'connection-type': str,
104 108
             'max-parallel-jobs': int,
105 109
         }
106
-        pool = {
110
+        pool = ConfigPool.getCommonSchemaDict()
111
+        pool.update({
107 112
             'name': str,
108 113
             'nodes': [pool_node],
109
-        }
114
+        })
110 115
         return v.Schema({'pools': [pool]})
111 116
 
112 117
     def getSupportedLabels(self, pool_name=None):

+ 10
- 7
nodepool/driver/test/config.py View File

@@ -12,7 +12,6 @@
12 12
 # License for the specific language governing permissions and limitations
13 13
 # under the License.
14 14
 
15
-import math
16 15
 import voluptuous as v
17 16
 
18 17
 from nodepool.driver import ConfigPool
@@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig
20 19
 
21 20
 
22 21
 class TestPool(ConfigPool):
23
-    pass
22
+    def load(self, pool_config):
23
+        super().load(pool_config)
24
+        self.name = pool_config['name']
25
+        self.labels = pool_config['labels']
24 26
 
25 27
 
26 28
 class TestConfig(ProviderConfig):
@@ -43,18 +45,19 @@ class TestConfig(ProviderConfig):
43 45
         self.labels = set()
44 46
         for pool in self.provider.get('pools', []):
45 47
             testpool = TestPool()
46
-            testpool.name = pool['name']
48
+            testpool.load(pool)
47 49
             testpool.provider = self
48
-            testpool.max_servers = pool.get('max-servers', math.inf)
49
-            testpool.labels = pool['labels']
50 50
             for label in pool['labels']:
51 51
                 self.labels.add(label)
52 52
                 newconfig.labels[label].pools.append(testpool)
53 53
             self.pools[pool['name']] = testpool
54 54
 
55 55
     def getSchema(self):
56
-        pool = {'name': str,
57
-                'labels': [str]}
56
+        pool = ConfigPool.getCommonSchemaDict()
57
+        pool.update({
58
+            'name': str,
59
+            'labels': [str]
60
+        })
58 61
         return v.Schema({'pools': [pool]})
59 62
 
60 63
     def getSupportedLabels(self, pool_name=None):

+ 11
- 5
nodepool/tests/unit/test_config_comparisons.py View File

@@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool
28 28
 from nodepool.driver.static.config import StaticProviderConfig
29 29
 
30 30
 
31
+class TempConfigPool(ConfigPool):
32
+    def load(self):
33
+        pass
34
+
35
+
31 36
 class TestConfigComparisons(tests.BaseTestCase):
32 37
 
33 38
     def test_ConfigPool(self):
34
-        a = ConfigPool()
35
-        b = ConfigPool()
39
+
40
+        a = TempConfigPool()
41
+        b = TempConfigPool()
36 42
         self.assertEqual(a, b)
37 43
         a.max_servers = 5
38 44
         self.assertNotEqual(a, b)
@@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase):
94 100
         a.max_servers = 5
95 101
         self.assertNotEqual(a, b)
96 102
 
97
-        c = ConfigPool()
103
+        c = TempConfigPool()
98 104
         d = ProviderPool()
99
-        self.assertNotEqual(c, d)
105
+        self.assertNotEqual(d, c)
100 106
 
101 107
     def test_OpenStackProviderConfig(self):
102 108
         provider = {'name': 'foo'}
@@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase):
114 120
         # intentionally change an attribute of the base class
115 121
         a.max_servers = 5
116 122
         self.assertNotEqual(a, b)
117
-        c = ConfigPool()
123
+        c = TempConfigPool()
118 124
         self.assertNotEqual(b, c)
119 125
 
120 126
     def test_StaticProviderConfig(self):

Loading…
Cancel
Save