Browse Source

Merge "Extract out common config parsing for ConfigPool"

tags/3.5.0
Zuul 4 months ago
parent
commit
1fe5fb60c5

+ 39
- 1
nodepool/driver/__init__.py View File

@@ -824,7 +824,11 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
824 824
         return not self.__eq__(other)
825 825
 
826 826
 
827
-class ConfigPool(ConfigValue):
827
+class ConfigPool(ConfigValue, metaclass=abc.ABCMeta):
828
+    '''
829
+    Base class for a single pool as defined in the configuration file.
830
+    '''
831
+
828 832
     def __init__(self):
829 833
         self.labels = {}
830 834
         self.max_servers = math.inf
@@ -837,6 +841,40 @@ class ConfigPool(ConfigValue):
837 841
                     self.node_attributes == other.node_attributes)
838 842
         return False
839 843
 
844
+    @classmethod
845
+    def getCommonSchemaDict(self):
846
+        '''
847
+        Return the schema dict for common pool attributes.
848
+
849
+        When a driver validates its own configuration schema, it should call
850
+        this class method to get and include the common pool attributes in
851
+        the schema.
852
+
853
+        The `labels` attribute, though common, can vary its type across
854
+        drivers so it is not returned in the schema.
855
+        '''
856
+        return {
857
+            'max-servers': int,
858
+            'node-attributes': dict,
859
+        }
860
+
861
+    @abc.abstractmethod
862
+    def load(self, pool_config):
863
+        '''
864
+        Load pool config options from the parsed configuration file.
865
+
866
+        Subclasses are expected to call the parent method so that common
867
+        configuration values are loaded properly.
868
+
869
+        Although `labels` is a common attribute, each driver may
870
+        define it differently, so we cannot parse that attribute here.
871
+
872
+        :param dict pool_config: A single pool config section from which we
873
+            will load the values.
874
+        '''
875
+        self.max_servers = pool_config.get('max-servers', math.inf)
876
+        self.node_attributes = pool_config.get('node-attributes')
877
+
840 878
 
841 879
 class DriverConfig(ConfigValue):
842 880
     def __init__(self):

+ 18
- 13
nodepool/driver/kubernetes/config.py View File

@@ -45,6 +45,20 @@ class KubernetesPool(ConfigPool):
45 45
     def __repr__(self):
46 46
         return "<KubernetesPool %s>" % self.name
47 47
 
48
+    def load(self, pool_config, full_config):
49
+        super().load(pool_config)
50
+        self.name = pool_config['name']
51
+        self.labels = {}
52
+        for label in pool_config.get('labels', []):
53
+            pl = KubernetesLabel()
54
+            pl.name = label['name']
55
+            pl.type = label['type']
56
+            pl.image = label.get('image')
57
+            pl.image_pull = label.get('image-pull', 'IfNotPresent')
58
+            pl.pool = self
59
+            self.labels[pl.name] = pl
60
+            full_config.labels[label['name']].pools.append(self)
61
+
48 62
 
49 63
 class KubernetesProviderConfig(ProviderConfig):
50 64
     def __init__(self, driver, provider):
@@ -72,19 +86,9 @@ class KubernetesProviderConfig(ProviderConfig):
72 86
         self.context = self.provider['context']
73 87
         for pool in self.provider.get('pools', []):
74 88
             pp = KubernetesPool()
75
-            pp.name = pool['name']
89
+            pp.load(pool, config)
76 90
             pp.provider = self
77 91
             self.pools[pp.name] = pp
78
-            pp.labels = {}
79
-            for label in pool.get('labels', []):
80
-                pl = KubernetesLabel()
81
-                pl.name = label['name']
82
-                pl.type = label['type']
83
-                pl.image = label.get('image')
84
-                pl.image_pull = label.get('image-pull', 'IfNotPresent')
85
-                pl.pool = pp
86
-                pp.labels[pl.name] = pl
87
-                config.labels[label['name']].pools.append(pp)
88 92
 
89 93
     def getSchema(self):
90 94
         k8s_label = {
@@ -94,10 +98,11 @@ class KubernetesProviderConfig(ProviderConfig):
94 98
             'image-pull': str,
95 99
         }
96 100
 
97
-        pool = {
101
+        pool = ConfigPool.getCommonSchemaDict()
102
+        pool.update({
98 103
             v.Required('name'): str,
99 104
             v.Required('labels'): [k8s_label],
100
-        }
105
+        })
101 106
 
102 107
         provider = {
103 108
             v.Required('pools'): [pool],

+ 62
- 50
nodepool/driver/openstack/config.py View File

@@ -149,6 +149,64 @@ class ProviderPool(ConfigPool):
149 149
     def __repr__(self):
150 150
         return "<ProviderPool %s>" % self.name
151 151
 
152
+    def load(self, pool_config, full_config, provider):
153
+        '''
154
+        Load pool configuration options.
155
+
156
+        :param dict pool_config: A single pool config section from which we
157
+            will load the values.
158
+        :param dict full_config: The full nodepool config.
159
+        :param OpenStackProviderConfig: The calling provider object.
160
+        '''
161
+        super().load(pool_config)
162
+
163
+        self.provider = provider
164
+        self.name = pool_config['name']
165
+        self.max_cores = pool_config.get('max-cores', math.inf)
166
+        self.max_ram = pool_config.get('max-ram', math.inf)
167
+        self.ignore_provider_quota = pool_config.get('ignore-provider-quota',
168
+                                                     False)
169
+        self.azs = pool_config.get('availability-zones')
170
+        self.networks = pool_config.get('networks', [])
171
+        self.security_groups = pool_config.get('security-groups', [])
172
+        self.auto_floating_ip = bool(pool_config.get('auto-floating-ip', True))
173
+        self.host_key_checking = bool(pool_config.get('host-key-checking',
174
+                                                      True))
175
+
176
+        for label in pool_config.get('labels', []):
177
+            pl = ProviderLabel()
178
+            pl.name = label['name']
179
+            pl.pool = self
180
+            self.labels[pl.name] = pl
181
+            diskimage = label.get('diskimage', None)
182
+            if diskimage:
183
+                pl.diskimage = full_config.diskimages[diskimage]
184
+            else:
185
+                pl.diskimage = None
186
+            cloud_image_name = label.get('cloud-image', None)
187
+            if cloud_image_name:
188
+                cloud_image = provider.cloud_images.get(cloud_image_name, None)
189
+                if not cloud_image:
190
+                    raise ValueError(
191
+                        "cloud-image %s does not exist in provider %s"
192
+                        " but is referenced in label %s" %
193
+                        (cloud_image_name, self.name, pl.name))
194
+            else:
195
+                cloud_image = None
196
+            pl.cloud_image = cloud_image
197
+            pl.min_ram = label.get('min-ram', 0)
198
+            pl.flavor_name = label.get('flavor-name', None)
199
+            pl.key_name = label.get('key-name')
200
+            pl.console_log = label.get('console-log', False)
201
+            pl.boot_from_volume = bool(label.get('boot-from-volume',
202
+                                                 False))
203
+            pl.volume_size = label.get('volume-size', 50)
204
+            pl.instance_properties = label.get('instance-properties',
205
+                                               None)
206
+
207
+            top_label = full_config.labels[pl.name]
208
+            top_label.pools.append(self)
209
+
152 210
 
153 211
 class OpenStackProviderConfig(ProviderConfig):
154 212
     def __init__(self, driver, provider):
@@ -263,53 +321,8 @@ class OpenStackProviderConfig(ProviderConfig):
263 321
 
264 322
         for pool in self.provider.get('pools', []):
265 323
             pp = ProviderPool()
266
-            pp.name = pool['name']
267
-            pp.provider = self
324
+            pp.load(pool, config, self)
268 325
             self.pools[pp.name] = pp
269
-            pp.max_cores = pool.get('max-cores', math.inf)
270
-            pp.max_servers = pool.get('max-servers', math.inf)
271
-            pp.max_ram = pool.get('max-ram', math.inf)
272
-            pp.ignore_provider_quota = pool.get('ignore-provider-quota', False)
273
-            pp.azs = pool.get('availability-zones')
274
-            pp.networks = pool.get('networks', [])
275
-            pp.security_groups = pool.get('security-groups', [])
276
-            pp.auto_floating_ip = bool(pool.get('auto-floating-ip', True))
277
-            pp.host_key_checking = bool(pool.get('host-key-checking', True))
278
-            pp.node_attributes = pool.get('node-attributes')
279
-
280
-            for label in pool.get('labels', []):
281
-                pl = ProviderLabel()
282
-                pl.name = label['name']
283
-                pl.pool = pp
284
-                pp.labels[pl.name] = pl
285
-                diskimage = label.get('diskimage', None)
286
-                if diskimage:
287
-                    pl.diskimage = config.diskimages[diskimage]
288
-                else:
289
-                    pl.diskimage = None
290
-                cloud_image_name = label.get('cloud-image', None)
291
-                if cloud_image_name:
292
-                    cloud_image = self.cloud_images.get(cloud_image_name, None)
293
-                    if not cloud_image:
294
-                        raise ValueError(
295
-                            "cloud-image %s does not exist in provider %s"
296
-                            " but is referenced in label %s" %
297
-                            (cloud_image_name, self.name, pl.name))
298
-                else:
299
-                    cloud_image = None
300
-                pl.cloud_image = cloud_image
301
-                pl.min_ram = label.get('min-ram', 0)
302
-                pl.flavor_name = label.get('flavor-name', None)
303
-                pl.key_name = label.get('key-name')
304
-                pl.console_log = label.get('console-log', False)
305
-                pl.boot_from_volume = bool(label.get('boot-from-volume',
306
-                                                     False))
307
-                pl.volume_size = label.get('volume-size', 50)
308
-                pl.instance_properties = label.get('instance-properties',
309
-                                                   None)
310
-
311
-                top_label = config.labels[pl.name]
312
-                top_label.pools.append(pp)
313 326
 
314 327
     def getSchema(self):
315 328
         provider_diskimage = {
@@ -358,20 +371,19 @@ class OpenStackProviderConfig(ProviderConfig):
358 371
                            v.Any(label_min_ram, label_flavor_name),
359 372
                            v.Any(label_diskimage, label_cloud_image))
360 373
 
361
-        pool = {
374
+        pool = ConfigPool.getCommonSchemaDict()
375
+        pool.update({
362 376
             'name': str,
363 377
             'networks': [str],
364 378
             'auto-floating-ip': bool,
365 379
             'host-key-checking': bool,
366 380
             'ignore-provider-quota': bool,
367 381
             'max-cores': int,
368
-            'max-servers': int,
369 382
             'max-ram': int,
370 383
             'labels': [pool_label],
371
-            'node-attributes': dict,
372 384
             'availability-zones': [str],
373 385
             'security-groups': [str]
374
-        }
386
+        })
375 387
 
376 388
         return v.Schema({
377 389
             'region-name': str,

+ 31
- 26
nodepool/driver/static/config.py View File

@@ -41,6 +41,33 @@ class StaticPool(ConfigPool):
41 41
     def __repr__(self):
42 42
         return "<StaticPool %s>" % self.name
43 43
 
44
+    def load(self, pool_config, full_config):
45
+        super().load(pool_config)
46
+        self.name = pool_config['name']
47
+        # WARNING: This intentionally changes the type!
48
+        self.labels = set()
49
+        for node in pool_config.get('nodes', []):
50
+            self.nodes.append({
51
+                'name': node['name'],
52
+                'labels': as_list(node['labels']),
53
+                'host-key': as_list(node.get('host-key', [])),
54
+                'timeout': int(node.get('timeout', 5)),
55
+                # Read ssh-port values for backward compat, but prefer port
56
+                'connection-port': int(
57
+                    node.get('connection-port', node.get('ssh-port', 22))),
58
+                'connection-type': node.get('connection-type', 'ssh'),
59
+                'username': node.get('username', 'zuul'),
60
+                'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
61
+            })
62
+            if isinstance(node['labels'], str):
63
+                for label in node['labels'].split():
64
+                    self.labels.add(label)
65
+                    full_config.labels[label].pools.append(self)
66
+            elif isinstance(node['labels'], list):
67
+                for label in node['labels']:
68
+                    self.labels.add(label)
69
+                    full_config.labels[label].pools.append(self)
70
+
44 71
 
45 72
 class StaticProviderConfig(ProviderConfig):
46 73
     def __init__(self, *args, **kwargs):
@@ -65,32 +92,9 @@ class StaticProviderConfig(ProviderConfig):
65 92
     def load(self, config):
66 93
         for pool in self.provider.get('pools', []):
67 94
             pp = StaticPool()
68
-            pp.name = pool['name']
95
+            pp.load(pool, config)
69 96
             pp.provider = self
70 97
             self.pools[pp.name] = pp
71
-            # WARNING: This intentionally changes the type!
72
-            pp.labels = set()
73
-            for node in pool.get('nodes', []):
74
-                pp.nodes.append({
75
-                    'name': node['name'],
76
-                    'labels': as_list(node['labels']),
77
-                    'host-key': as_list(node.get('host-key', [])),
78
-                    'timeout': int(node.get('timeout', 5)),
79
-                    # Read ssh-port values for backward compat, but prefer port
80
-                    'connection-port': int(
81
-                        node.get('connection-port', node.get('ssh-port', 22))),
82
-                    'connection-type': node.get('connection-type', 'ssh'),
83
-                    'username': node.get('username', 'zuul'),
84
-                    'max-parallel-jobs': int(node.get('max-parallel-jobs', 1)),
85
-                })
86
-                if isinstance(node['labels'], str):
87
-                    for label in node['labels'].split():
88
-                        pp.labels.add(label)
89
-                        config.labels[label].pools.append(pp)
90
-                elif isinstance(node['labels'], list):
91
-                    for label in node['labels']:
92
-                        pp.labels.add(label)
93
-                        config.labels[label].pools.append(pp)
94 98
 
95 99
     def getSchema(self):
96 100
         pool_node = {
@@ -103,10 +107,11 @@ class StaticProviderConfig(ProviderConfig):
103 107
             'connection-type': str,
104 108
             'max-parallel-jobs': int,
105 109
         }
106
-        pool = {
110
+        pool = ConfigPool.getCommonSchemaDict()
111
+        pool.update({
107 112
             'name': str,
108 113
             'nodes': [pool_node],
109
-        }
114
+        })
110 115
         return v.Schema({'pools': [pool]})
111 116
 
112 117
     def getSupportedLabels(self, pool_name=None):

+ 10
- 7
nodepool/driver/test/config.py View File

@@ -12,7 +12,6 @@
12 12
 # License for the specific language governing permissions and limitations
13 13
 # under the License.
14 14
 
15
-import math
16 15
 import voluptuous as v
17 16
 
18 17
 from nodepool.driver import ConfigPool
@@ -20,7 +19,10 @@ from nodepool.driver import ProviderConfig
20 19
 
21 20
 
22 21
 class TestPool(ConfigPool):
23
-    pass
22
+    def load(self, pool_config):
23
+        super().load(pool_config)
24
+        self.name = pool_config['name']
25
+        self.labels = pool_config['labels']
24 26
 
25 27
 
26 28
 class TestConfig(ProviderConfig):
@@ -43,18 +45,19 @@ class TestConfig(ProviderConfig):
43 45
         self.labels = set()
44 46
         for pool in self.provider.get('pools', []):
45 47
             testpool = TestPool()
46
-            testpool.name = pool['name']
48
+            testpool.load(pool)
47 49
             testpool.provider = self
48
-            testpool.max_servers = pool.get('max-servers', math.inf)
49
-            testpool.labels = pool['labels']
50 50
             for label in pool['labels']:
51 51
                 self.labels.add(label)
52 52
                 newconfig.labels[label].pools.append(testpool)
53 53
             self.pools[pool['name']] = testpool
54 54
 
55 55
     def getSchema(self):
56
-        pool = {'name': str,
57
-                'labels': [str]}
56
+        pool = ConfigPool.getCommonSchemaDict()
57
+        pool.update({
58
+            'name': str,
59
+            'labels': [str]
60
+        })
58 61
         return v.Schema({'pools': [pool]})
59 62
 
60 63
     def getSupportedLabels(self, pool_name=None):

+ 11
- 5
nodepool/tests/unit/test_config_comparisons.py View File

@@ -28,11 +28,17 @@ from nodepool.driver.static.config import StaticPool
28 28
 from nodepool.driver.static.config import StaticProviderConfig
29 29
 
30 30
 
31
+class TempConfigPool(ConfigPool):
32
+    def load(self):
33
+        pass
34
+
35
+
31 36
 class TestConfigComparisons(tests.BaseTestCase):
32 37
 
33 38
     def test_ConfigPool(self):
34
-        a = ConfigPool()
35
-        b = ConfigPool()
39
+
40
+        a = TempConfigPool()
41
+        b = TempConfigPool()
36 42
         self.assertEqual(a, b)
37 43
         a.max_servers = 5
38 44
         self.assertNotEqual(a, b)
@@ -94,9 +100,9 @@ class TestConfigComparisons(tests.BaseTestCase):
94 100
         a.max_servers = 5
95 101
         self.assertNotEqual(a, b)
96 102
 
97
-        c = ConfigPool()
103
+        c = TempConfigPool()
98 104
         d = ProviderPool()
99
-        self.assertNotEqual(c, d)
105
+        self.assertNotEqual(d, c)
100 106
 
101 107
     def test_OpenStackProviderConfig(self):
102 108
         provider = {'name': 'foo'}
@@ -114,7 +120,7 @@ class TestConfigComparisons(tests.BaseTestCase):
114 120
         # intentionally change an attribute of the base class
115 121
         a.max_servers = 5
116 122
         self.assertNotEqual(a, b)
117
-        c = ConfigPool()
123
+        c = TempConfigPool()
118 124
         self.assertNotEqual(b, c)
119 125
 
120 126
     def test_StaticProviderConfig(self):

Loading…
Cancel
Save