From 97181b9a32896e22feaa99b7ddcf01c41382a320 Mon Sep 17 00:00:00 2001 From: "James E. Blair" Date: Mon, 8 Mar 2021 19:48:45 -0800 Subject: [PATCH] Add azure state machine driver Change-Id: Icf1e2aa0fe36410eee1a0b0f178c3472a2cd2a4c --- nodepool/driver/azurestate/__init__.py | 25 ++ nodepool/driver/azurestate/adapter.py | 418 +++++++++++++++++++++++++ nodepool/driver/azurestate/azul.py | 269 ++++++++++++++++ nodepool/driver/azurestate/config.py | 225 +++++++++++++ nodepool/driver/statemachine.py | 21 +- 5 files changed, 952 insertions(+), 6 deletions(-) create mode 100644 nodepool/driver/azurestate/__init__.py create mode 100644 nodepool/driver/azurestate/adapter.py create mode 100644 nodepool/driver/azurestate/azul.py create mode 100644 nodepool/driver/azurestate/config.py diff --git a/nodepool/driver/azurestate/__init__.py b/nodepool/driver/azurestate/__init__.py new file mode 100644 index 000000000..6fb6353b6 --- /dev/null +++ b/nodepool/driver/azurestate/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2021 Acme Gating, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from nodepool.driver.statemachine import StateMachineDriver +from nodepool.driver.azurestate.config import AzureProviderConfig +from nodepool.driver.azurestate.adapter import AzureAdapter + + +class AzureDriver(StateMachineDriver): + def getProviderConfig(self, provider): + return AzureProviderConfig(self, provider) + + def getAdapter(self, provider_config): + return AzureAdapter(provider_config) diff --git a/nodepool/driver/azurestate/adapter.py b/nodepool/driver/azurestate/adapter.py new file mode 100644 index 000000000..ebb9991b8 --- /dev/null +++ b/nodepool/driver/azurestate/adapter.py @@ -0,0 +1,418 @@ +# Copyright 2021 Acme Gating, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import math +import logging +import json + +import cachetools.func + +from nodepool.driver.utils import QuotaInformation, RateLimiter +from nodepool.driver import statemachine +from . import azul + + +class AzureInstance(statemachine.Instance): + def __init__(self, vm, nic=None, pip=None): + self.external_id = vm['name'] + self.metadata = vm['tags'] or {} + self.private_ipv4 = None + self.public_ipv4 = None + self.public_ipv6 = None + + if nic: + for ip_config_data in nic['properties']['ipConfigurations']: + ip_config_prop = ip_config_data['properties'] + if ip_config_prop['privateIPAddressVersion'] == 'IPv4': + self.private_ipv4 = ip_config_prop['privateIPAddress'] + # public_ipv6 + + if pip: + self.public_ipv4 = pip['properties'].get('ipAddress') + + self.interface_ip = self.public_ipv4 or self.private_ipv4 + self.region = vm['location'] + self.az = '' + + +class AzureDeleteStateMachine(statemachine.StateMachine): + VM_DELETING = 'deleting vm' + NIC_DELETING = 'deleting nic' + PIP_DELETING = 'deleting pip' + DISK_DELETING = 'deleting disk' + COMPLETE = 'complete' + + def __init__(self, adapter, external_id): + super().__init__() + self.adapter = adapter + self.external_id = external_id + + def advance(self): + if self.state == self.START: + self.vm = self.adapter._deleteVirtualMachine( + self.external_id) + self.state = self.VM_DELETING + + if self.state == self.VM_DELETING: + self.vm = self.adapter._refresh_delete(self.vm) + if self.vm is None: + self.nic = self.adapter._deleteNetworkInterface( + self.external_id + '-nic') + self.state = self.NIC_DELETING + + if self.state == self.NIC_DELETING: + self.nic = self.adapter._refresh_delete(self.nic) + if self.nic is None: + self.pip = self.adapter._deletePublicIPAddress( + self.external_id + '-nic-pip') + self.state = self.PIP_DELETING + + if self.state == self.PIP_DELETING: + self.pip = self.adapter._refresh_delete(self.pip) + if self.pip is None: + self.disks = [] + for disk in self.adapter._listDisks(): + if disk['tags'] is not None and \ + disk['tags'].get('nodepool_id') == self.external_id: + disk = self.adapter._deleteDisk(disk['name']) + self.disks.append(disk) + self.state = self.DISK_DELETING + + if self.state == self.DISK_DELETING: + all_deleted = True + for disk in self.disks: + disk = self.adapter._refresh_delete(disk) + if disk: + all_deleted = False + if all_deleted: + self.state = self.COMPLETE + self.complete = True + + +class AzureCreateStateMachine(statemachine.StateMachine): + PIP_CREATING = 'creating pip' + NIC_CREATING = 'creating nic' + VM_CREATING = 'creating vm' + NIC_QUERY = 'querying nic' + PIP_QUERY = 'querying pip' + COMPLETE = 'complete' + + def __init__(self, adapter, hostname, label, metadata, retries): + super().__init__() + self.adapter = adapter + self.retries = retries + self.metadata = metadata + self.tags = label.tags or {} + self.tags.update(metadata) + self.hostname = hostname + self.label = label + self.pip = None + self.nic = None + self.vm = None + + def advance(self): + if self.state == self.START: + self.pip = self.adapter._createPublicIPAddress( + self.tags, self.hostname) + self.state = self.PIP_CREATING + self.external_id = self.hostname + + if self.state == self.PIP_CREATING: + self.pip = self.adapter._refresh(self.pip) + if self.adapter._succeeded(self.pip): + self.nic = self.adapter._createNetworkInterface( + self.tags, self.hostname, self.pip) + self.state = self.NIC_CREATING + else: + return + + if self.state == self.NIC_CREATING: + self.nic = self.adapter._refresh(self.nic) + if self.adapter._succeeded(self.nic): + self.vm = self.adapter._createVirtualMachine( + self.label, self.tags, self.hostname, self.nic) + self.state = self.VM_CREATING + else: + return + + if self.state == self.VM_CREATING: + self.vm = self.adapter._refresh(self.vm) + # if 404: + # increment retries + # state = self.NIC_CREATING + # if error: + # if retries too big: raise error + # delete vm + if self.adapter._succeeded(self.vm): + self.state = self.NIC_QUERY + else: + return + + if self.state == self.NIC_QUERY: + self.nic = self.adapter._refresh(self.nic, force=True) + for ip_config_data in self.nic['properties']['ipConfigurations']: + ip_config_prop = ip_config_data['properties'] + if ip_config_prop['privateIPAddressVersion'] == 'IPv4': + if 'privateIPAddress' in ip_config_prop: + self.state = self.PIP_QUERY + + if self.state == self.PIP_QUERY: + self.pip = self.adapter._refresh(self.pip, force=True) + if 'ipAddress' in self.pip['properties']: + self.state = self.COMPLETE + + if self.state == self.COMPLETE: + self.complete = True + return AzureInstance(self.vm, self.nic, self.pip) + + +class AzureAdapter(statemachine.Adapter): + log = logging.getLogger("nodepool.driver.azure.AzureAdapter") + + def __init__(self, provider_config): + self.provider = provider_config + self.resource_group = self.provider.resource_group + self.resource_group_location = self.provider.resource_group_location + self.rate_limiter = RateLimiter(self.provider.name, + self.provider.rate_limit) + with open(self.provider.auth_path) as f: + self.azul = azul.AzureCloud(json.load(f)) + + def getCreateStateMachine(self, hostname, label, metadata, retries): + return AzureCreateStateMachine( + self, hostname, label, metadata, retries) + + def getDeleteStateMachine(self, external_id): + return AzureDeleteStateMachine(self, external_id) + + def cleanupLeakedResources(self, known_nodes, metadata): + for vm in self._listVirtualMachines(): + node_id = self._metadataMatches(vm, metadata) + if (node_id and node_id not in known_nodes): + self.log.info(f"Deleting leaked vm: {vm['name']}") + self.azul.virtual_machines.delete( + self.resource_group, vm['name']) + for nic in self._listNetworkInterfaces(): + node_id = self._metadataMatches(nic, metadata) + if (node_id and node_id not in known_nodes): + self.log.info(f"Deleting leaked nic: {nic['name']}") + self.azul.network_interfaces.delete( + self.resource_group, nic['name']) + for pip in self._listPublicIPAddresses(): + node_id = self._metadataMatches(pip, metadata) + if (node_id and node_id not in known_nodes): + self.log.info(f"Deleting leaked pip: {pip['name']}") + self.azul.public_ip_addresses.delete( + self.resource_group, pip['name']) + for disk in self._listDisks(): + node_id = self._metadataMatches(disk, metadata) + if (node_id and node_id not in known_nodes): + self.log.info(f"Deleting leaked disk: {disk['name']}") + self.azul.disks.delete(self.resource_group, disk['name']) + + def listInstances(self): + for vm in self._listVirtualMachines(): + yield AzureInstance(vm) + + def getQuotaLimits(self): + return QuotaInformation(default=math.inf) + + def getQuotaForLabel(self, label): + return QuotaInformation(instances=1) + + # Local implementation below + + def _metadataMatches(self, obj, metadata): + if 'tags' not in obj: + return None + for k, v in metadata.items(): + if obj['tags'].get(k) != v: + return None + return obj['tags']['nodepool_node_id'] + + @staticmethod + def _succeeded(obj): + return obj['properties']['provisioningState'] == 'Succeeded' + + def _refresh(self, obj, force=False): + if self._succeeded(obj) and not force: + return obj + + if obj['type'] == 'Microsoft.Network/publicIPAddresses': + l = self._listPublicIPAddresses() + if obj['type'] == 'Microsoft.Network/networkInterfaces': + l = self._listNetworkInterfaces() + if obj['type'] == 'Microsoft.Compute/virtualMachines': + l = self._listVirtualMachines() + + for new_obj in l: + if new_obj['id'] == obj['id']: + return new_obj + return obj + + def _refresh_delete(self, obj): + if obj is None: + return obj + + if obj['type'] == 'Microsoft.Network/publicIPAddresses': + l = self._listPublicIPAddresses() + if obj['type'] == 'Microsoft.Network/networkInterfaces': + l = self._listNetworkInterfaces() + if obj['type'] == 'Microsoft.Compute/virtualMachines': + l = self._listVirtualMachines() + + for new_obj in l: + if new_obj['id'] == obj['id']: + return new_obj + return None + + @cachetools.func.ttl_cache(maxsize=1, ttl=10) + def _listPublicIPAddresses(self): + return self.azul.public_ip_addresses.list(self.resource_group) + + def _createPublicIPAddress(self, tags, hostname): + v4_params_create = { + 'location': self.provider.location, + 'tags': tags, + 'properties': { + 'publicIpAllocationMethod': 'dynamic', + }, + } + return self.azul.public_ip_addresses.create( + self.resource_group, + "%s-nic-pip" % hostname, + v4_params_create, + ) + + def _deletePublicIPAddress(self, name): + for pip in self._listPublicIPAddresses(): + if pip['name'] == name: + break + else: + return None + self.azul.public_ip_addresses.delete(self.resource_group, name) + return pip + + @cachetools.func.ttl_cache(maxsize=1, ttl=10) + def _listNetworkInterfaces(self): + return self.azul.network_interfaces.list(self.resource_group) + + def _createNetworkInterface(self, tags, hostname, pip): + nic_data = { + 'location': self.provider.location, + 'tags': tags, + 'properties': { + 'ipConfigurations': [{ + 'name': "nodepool-v4-ip-config", + 'properties': { + 'privateIpAddressVersion': 'IPv4', + 'subnet': { + 'id': self.provider.subnet_id + }, + 'publicIpAddress': { + 'id': pip['id'] + } + } + }] + } + } + + if self.provider.ipv6: + nic_data['properties']['ipConfigurations'].append({ + 'name': "nodepool-v6-ip-config", + 'properties': { + 'privateIpAddressVersion': 'IPv6', + 'subnet': { + 'id': self.provider.subnet_id + } + } + }) + + return self.azul.network_interfaces.create( + self.resource_group, + "%s-nic" % hostname, + nic_data + ) + + def _deleteNetworkInterface(self, name): + for nic in self._listNetworkInterfaces(): + if nic['name'] == name: + break + else: + return None + self.azul.network_interfaces.delete(self.resource_group, name) + return nic + + @cachetools.func.ttl_cache(maxsize=1, ttl=10) + def _listVirtualMachines(self): + return self.azul.virtual_machines.list(self.resource_group) + + def _createVirtualMachine(self, label, tags, hostname, nic): + return self.azul.virtual_machines.create( + self.resource_group, hostname, { + 'location': self.provider.location, + 'tags': tags, + 'properties': { + 'osProfile': { + 'computerName': hostname, + 'adminUsername': label.cloud_image.username, + 'linuxConfiguration': { + 'ssh': { + 'publicKeys': [{ + 'path': "/home/%s/.ssh/authorized_keys" % ( + label.cloud_image.username), + 'keyData': label.cloud_image.key, + }] + }, + "disablePasswordAuthentication": True, + } + }, + 'hardwareProfile': { + 'vmSize': label.hardware_profile["vm-size"] + }, + 'storageProfile': { + 'imageReference': label.cloud_image.image_reference + }, + 'networkProfile': { + 'networkInterfaces': [{ + 'id': nic['id'], + 'properties': { + 'primary': True, + } + }] + }, + }, + }) + + def _deleteVirtualMachine(self, name): + for vm in self._listVirtualMachines(): + if vm['name'] == name: + break + else: + return None + self.azul.virtual_machines.delete(self.resource_group, name) + return vm + + @cachetools.func.ttl_cache(maxsize=1, ttl=10) + def _listDisks(self): + return self.azul.disks.list(self.resource_group) + + def _deleteDisk(self, name): + for disk in self._listNetworkInterfaces(): + if disk['name'] == name: + break + else: + return None + self.azul.disks.delete(self.resource_group, name) + return disk diff --git a/nodepool/driver/azurestate/azul.py b/nodepool/driver/azurestate/azul.py new file mode 100644 index 000000000..ede321517 --- /dev/null +++ b/nodepool/driver/azurestate/azul.py @@ -0,0 +1,269 @@ +# Copyright 2021 Acme Gating, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import requests +import logging +import time + + +class AzureAuth(requests.auth.AuthBase): + AUTH_URL = "https://login.microsoftonline.com/{tenantId}/oauth2/token" + + def __init__(self, credential): + self.log = logging.getLogger("azul.auth") + self.credential = credential + self.token = None + self.expiration = time.time() + + def refresh(self): + if self.expiration - time.time() < 60: + self.log.debug('Refreshing authentication token') + url = self.AUTH_URL.format(**self.credential) + data = { + 'grant_type': 'client_credentials', + 'client_id': self.credential['clientId'], + 'client_secret': self.credential['clientSecret'], + 'resource': 'https://management.azure.com/', + } + r = requests.post(url, data) + ret = r.json() + self.token = ret['access_token'] + self.expiration = float(ret['expires_on']) + + def __call__(self, r): + self.refresh() + r.headers["authorization"] = "Bearer " + self.token + return r + + +class AzureError(Exception): + def __init__(self, status_code, message): + super().__init__(message) + self.status_code = status_code + + +class AzureNotFoundError(AzureError): + def __init__(self, status_code, message): + super().__init__(status_code, message) + + +class AzureResourceGroupsCRUD: + def __init__(self, cloud, version): + self.cloud = cloud + self.version = version + + def url(self, url, **args): + base_url = ( + 'https://management.azure.com/subscriptions/{subscriptionId}' + '/resourcegroups/') + url = base_url + url + '?api-version={apiVersion}' + args = args.copy() + args.update(self.cloud.credential) + args['apiVersion'] = self.version + return url.format(**args) + + def list(self): + url = self.url('') + return self.cloud.paginate(self.cloud.get(url)) + + def get(self, name): + url = self.url(name) + return self.cloud.get(url) + + def create(self, name, params): + url = self.url(name) + return self.cloud.put(url, params) + + def delete(self, name): + url = self.url(name) + return self.cloud.delete(url) + + +class AzureCRUD: + def __init__(self, cloud, resource, version): + self.cloud = cloud + self.resource = resource + self.version = version + + def url(self, url, **args): + base_url = ( + 'https://management.azure.com/subscriptions/{subscriptionId}' + '/resourceGroups/{resourceGroupName}/providers/') + url = base_url + url + '?api-version={apiVersion}' + args = args.copy() + args.update(self.cloud.credential) + args['apiVersion'] = self.version + return url.format(**args) + + def id_url(self, url, **args): + base_url = 'https://management.azure.com' + url = base_url + url + '?api-version={apiVersion}' + args = args.copy() + args['apiVersion'] = self.version + return url.format(**args) + + def list(self, resource_group_name): + url = self.url( + self.resource, + resourceGroupName=resource_group_name, + ) + return self.cloud.paginate(self.cloud.get(url)) + + def get_by_id(self, resource_id): + url = self.id_url(resource_id) + return self.cloud.get(url) + + def get(self, resource_group_name, name): + url = self.url( + '{_resource}/{_resourceName}', + _resource=self.resource, + _resourceName=name, + resourceGroupName=resource_group_name, + ) + return self.cloud.get(url) + + def create(self, resource_group_name, name, params): + url = self.url( + '{_resource}/{_resourceName}', + _resource=self.resource, + _resourceName=name, + resourceGroupName=resource_group_name, + ) + return self.cloud.put(url, params) + + def delete(self, resource_group_name, name): + url = self.url( + '{_resource}/{_resourceName}', + _resource=self.resource, + _resourceName=name, + resourceGroupName=resource_group_name, + ) + return self.cloud.delete(url) + + +class AzureDictResponse(dict): + def __init__(self, response, *args): + super().__init__(*args) + self.response = response + self.last_retry = time.time() + + +class AzureListResponse(list): + def __init__(self, response, *args): + super().__init__(*args) + self.response = response + self.last_retry = time.time() + + +class AzureCloud: + TIMEOUT = 60 + + def __init__(self, credential): + self.credential = credential + self.session = requests.Session() + self.log = logging.getLogger("azul") + self.auth = AzureAuth(credential) + self.network_interfaces = AzureCRUD( + self, + 'Microsoft.Network/networkInterfaces', + '2020-07-01') + self.public_ip_addresses = AzureCRUD( + self, + 'Microsoft.Network/publicIPAddresses', + '2020-07-01') + self.virtual_machines = AzureCRUD( + self, + 'Microsoft.Compute/virtualMachines', + '2020-12-01') + self.disks = AzureCRUD( + self, + 'Microsoft.Compute/disks', + '2020-06-30') + self.resource_groups = AzureResourceGroupsCRUD( + self, + '2020-06-01') + + def get(self, url, codes=[200]): + return self.request('GET', url, None, codes) + + def put(self, url, data, codes=[200, 201]): + return self.request('PUT', url, data, codes) + + def delete(self, url, codes=[200, 201, 202, 204]): + return self.request('DELETE', url, None, codes) + + def request(self, method, url, data, codes): + self.log.debug('%s: %s %s' % (method, url, data)) + response = self.session.request( + method, url, json=data, + auth=self.auth, timeout=self.TIMEOUT, + headers={'Accept': 'application/json', + 'Accept-Encoding': 'gzip'}) + + self.log.debug("Received headers: %s", response.headers) + if response.status_code in codes: + if len(response.text): + self.log.debug("Received: %s", response.text) + ret_data = response.json() + if isinstance(ret_data, list): + return AzureListResponse(response, ret_data) + else: + return AzureDictResponse(response, ret_data) + self.log.debug("Empty response") + return AzureDictResponse(response, {}) + err = response.json() + self.log.error(response.text) + if response.status_code == 404: + raise AzureNotFoundError( + response.status_code, err['error']['message']) + else: + raise AzureError(response.status_code, err['error']['message']) + + def paginate(self, data): + ret = data['value'] + while 'nextLink' in data: + data = self.get(data['nextLink']) + ret += data['value'] + return ret + + def check_async_operation(self, response): + resp = response.response + location = resp.headers.get( + 'Azure-AsyncOperation', + resp.headers.get('Location', None)) + if not location: + self.log.debug("No async operation found") + return None + remain = (response.last_retry + + float(resp.headers.get('Retry-After', 2))) - time.time() + self.log.debug("remain time %s", remain) + if remain > 0: + time.sleep(remain) + response.last_retry = time.time() + return self.get(location) + + def wait_for_async_operation(self, response, timeout=600): + start = time.time() + while True: + if time.time() - start > timeout: + raise Exception("Timeout waiting for async operation") + ret = self.check_async_operation(response) + if ret is None: + return + if ret['status'] == 'InProgress': + continue + if ret['status'] == 'Succeeded': + return + raise Exception("Unhandled async operation result: %s", + ret['status']) diff --git a/nodepool/driver/azurestate/config.py b/nodepool/driver/azurestate/config.py new file mode 100644 index 000000000..2dd77c88b --- /dev/null +++ b/nodepool/driver/azurestate/config.py @@ -0,0 +1,225 @@ +# Copyright 2018 Red Hat +# Copyright 2021 Acme Gating, LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +import voluptuous as v +import os + +from nodepool.driver import ConfigPool +from nodepool.driver import ConfigValue +from nodepool.driver import ProviderConfig + + +class ProviderCloudImage(ConfigValue): + def __init__(self): + self.name = None + self.image_id = None + self.username = None + self.key = None + self.python_path = None + self.connection_type = None + self.connection_port = None + + def __eq__(self, other): + if isinstance(other, ProviderCloudImage): + return (self.name == other.name + and self.image_id == other.image_id + and self.username == other.username + and self.key == other.key + and self.python_path == other.python_path + and self.connection_type == other.connection_type + and self.connection_port == other.connection_port) + return False + + def __repr__(self): + return "" % self.name + + @property + def external_name(self): + '''Human readable version of external.''' + return self.image_id or self.name + + +class AzureLabel(ConfigValue): + def __eq__(self, other): + if ( # other.username != self.username or + # other.imageReference != self.imageReference or + other.hardware_profile != self.hardware_profile): + return False + return True + + +class AzurePool(ConfigPool): + def __eq__(self, other): + if other.labels != self.labels: + return False + return True + + def __repr__(self): + return "" % self.name + + def load(self, pool_config): + pass + + +class AzureProviderConfig(ProviderConfig): + def __init__(self, driver, provider): + self._pools = {} + self.driver_object = driver + self.rate_limit = None + self.launch_retries = None + super().__init__(provider) + + def __eq__(self, other): + if (other.location != self.location or + other.pools != self.pools): + return False + return True + + @property + def pools(self): + return self._pools + + @property + def manage_images(self): + return False + + @staticmethod + def reset(): + pass + + def load(self, config): + self.rate_limit = self.provider.get('rate-limit', 1) + self.launch_retries = self.provider.get('launch-retries', 3) + self.boot_timeout = self.provider.get('boot-timeout', 60) + + self.zuul_public_key = self.provider['zuul-public-key'] + self.location = self.provider['location'] + self.subnet_id = self.provider['subnet-id'] + self.ipv6 = self.provider.get('ipv6', False) + self.resource_group = self.provider['resource-group'] + self.resource_group_location = self.provider['resource-group-location'] + self.auth_path = self.provider.get( + 'auth-path', os.getenv('AZURE_AUTH_LOCATION', None)) + + default_port_mapping = { + 'ssh': 22, + 'winrm': 5986, + } + + self.cloud_images = {} + for image in self.provider['cloud-images']: + i = ProviderCloudImage() + i.name = image['name'] + i.username = image['username'] + # i.key = image['key'] + i.key = self.zuul_public_key + i.image_reference = image['image-reference'] + i.connection_type = image.get('connection-type', 'ssh') + i.connection_port = image.get( + 'connection-port', + default_port_mapping.get(i.connection_type, 22)) + self.cloud_images[i.name] = i + + for pool in self.provider.get('pools', []): + pp = AzurePool() + pp.name = pool['name'] + pp.provider = self + pp.max_servers = pool['max-servers'] + pp.use_internal_ip = bool(pool.get('use-internal-ip', False)) + pp.host_key_checking = bool(pool.get( + 'host-key-checking', True)) + self._pools[pp.name] = pp + pp.labels = {} + + for label in pool.get('labels', []): + pl = AzureLabel() + pl.name = label['name'] + pl.pool = pp + pp.labels[pl.name] = pl + + cloud_image_name = label['cloud-image'] + if cloud_image_name: + cloud_image = self.cloud_images.get( + cloud_image_name, None) + if not cloud_image: + raise ValueError( + "cloud-image %s does not exist in provider %s" + " but is referenced in label %s" % + (cloud_image_name, self.name, pl.name)) + # pl.imageReference = cloud_image['image-reference'] + pl.cloud_image = cloud_image + # pl.username = cloud_image.get('username', 'zuul') + else: + # pl.imageReference = None + pl.cloud_image = None + # pl.username = 'zuul' + + pl.hardware_profile = label['hardware-profile'] + + config.labels[label['name']].pools.append(pp) + pl.tags = label.get('tags', {}) + + def getSchema(self): + + azure_image_reference = { + v.Required('sku'): str, + v.Required('publisher'): str, + v.Required('version'): str, + v.Required('offer'): str, + } + + azure_hardware_profile = { + v.Required('vm-size'): str, + } + + provider_cloud_images = { + v.Required('name'): str, + 'username': str, + v.Required('image-reference'): azure_image_reference, + } + + azure_label = { + v.Required('name'): str, + v.Required('hardware-profile'): azure_hardware_profile, + v.Required('cloud-image'): str, + v.Optional('tags'): dict, + } + pool = ConfigPool.getCommonSchemaDict() + pool.update({ + v.Required('name'): str, + v.Required('labels'): [azure_label], + }) + + provider = ProviderConfig.getCommonSchemaDict() + provider.update({ + v.Required('zuul-public-key'): str, + v.Required('pools'): [pool], + v.Required('location'): str, + v.Required('resource-group'): str, + v.Required('resource-group-location'): str, + v.Required('subnet-id'): str, + v.Required('cloud-images'): [provider_cloud_images], + v.Optional('auth-path'): str, + }) + return v.Schema(provider) + + def getSupportedLabels(self, pool_name=None): + labels = set() + for pool in self._pools.values(): + if not pool_name or (pool.name == pool_name): + labels.update(pool.labels.keys()) + return labels diff --git a/nodepool/driver/statemachine.py b/nodepool/driver/statemachine.py index 762c6a0d8..51eda413b 100644 --- a/nodepool/driver/statemachine.py +++ b/nodepool/driver/statemachine.py @@ -157,8 +157,9 @@ class StateMachineNodeLauncher(stats.StatsReporter): raise Exception("Driver implementation error: state " "machine must produce external ID " "after first advancement") - self.updateNodeFromInstance(instance) - if state_machine.complete: + node.external_id = state_machine.external_id + self.zk.storeNode(node) + if state_machine.complete and not self.keyscan_future: self.log.debug("Submitting keyscan request") self.updateNodeFromInstance(instance) future = self.manager.keyscan_worker.submit( @@ -252,14 +253,16 @@ class StateMachineNodeDeleter: if node.external_id: state_machine.advance() self.log.debug(f"State machine for {node.id} at " - "{state_machine.state}") + f"{state_machine.state}") + else: + self.state_machine.complete = True if not self.state_machine.complete: return except exceptions.NotFound: self.log.info(f"Instance {node.external_id} not found in " - "provider {node.provider}") + f"provider {node.provider}") except Exception: self.log.exception("Exception deleting instance " f"{node.external_id} from {node.provider}:") @@ -423,15 +426,21 @@ class StateMachineProvider(Provider, QuotaSupport): self.state_machine_thread.start() def stop(self): + self.log.debug("Stopping") self.running = False + if self.state_machine_thread: + while self.launchers or self.deleters: + time.sleep(1) + self.running = False if self.keyscan_worker: self.keyscan_worker.shutdown() - if self.state_machine_thread: - self.running = False + self.log.debug("Stopped") def join(self): + self.log.debug("Joining") if self.state_machine_thread: self.state_machine_thread.join() + self.log.debug("Joined") def _runStateMachines(self): while self.running: