Add azure state machine driver
Change-Id: Icf1e2aa0fe36410eee1a0b0f178c3472a2cd2a4c
This commit is contained in:
		
							
								
								
									
										25
									
								
								nodepool/driver/azurestate/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								nodepool/driver/azurestate/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
# Copyright 2021 Acme Gating, LLC
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
			
		||||
# not use this file except in compliance with the License. You may obtain
 | 
			
		||||
# a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
# License for the specific language governing permissions and limitations
 | 
			
		||||
# under the License.
 | 
			
		||||
 | 
			
		||||
from nodepool.driver.statemachine import StateMachineDriver
 | 
			
		||||
from nodepool.driver.azurestate.config import AzureProviderConfig
 | 
			
		||||
from nodepool.driver.azurestate.adapter import AzureAdapter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureDriver(StateMachineDriver):
 | 
			
		||||
    def getProviderConfig(self, provider):
 | 
			
		||||
        return AzureProviderConfig(self, provider)
 | 
			
		||||
 | 
			
		||||
    def getAdapter(self, provider_config):
 | 
			
		||||
        return AzureAdapter(provider_config)
 | 
			
		||||
							
								
								
									
										418
									
								
								nodepool/driver/azurestate/adapter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										418
									
								
								nodepool/driver/azurestate/adapter.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,418 @@
 | 
			
		||||
# Copyright 2021 Acme Gating, LLC
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
			
		||||
# not use this file except in compliance with the License. You may obtain
 | 
			
		||||
# a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
# License for the specific language governing permissions and limitations
 | 
			
		||||
# under the License.
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import logging
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
import cachetools.func
 | 
			
		||||
 | 
			
		||||
from nodepool.driver.utils import QuotaInformation, RateLimiter
 | 
			
		||||
from nodepool.driver import statemachine
 | 
			
		||||
from . import azul
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureInstance(statemachine.Instance):
 | 
			
		||||
    def __init__(self, vm, nic=None, pip=None):
 | 
			
		||||
        self.external_id = vm['name']
 | 
			
		||||
        self.metadata = vm['tags'] or {}
 | 
			
		||||
        self.private_ipv4 = None
 | 
			
		||||
        self.public_ipv4 = None
 | 
			
		||||
        self.public_ipv6 = None
 | 
			
		||||
 | 
			
		||||
        if nic:
 | 
			
		||||
            for ip_config_data in nic['properties']['ipConfigurations']:
 | 
			
		||||
                ip_config_prop = ip_config_data['properties']
 | 
			
		||||
                if ip_config_prop['privateIPAddressVersion'] == 'IPv4':
 | 
			
		||||
                    self.private_ipv4 = ip_config_prop['privateIPAddress']
 | 
			
		||||
        # public_ipv6
 | 
			
		||||
 | 
			
		||||
        if pip:
 | 
			
		||||
            self.public_ipv4 = pip['properties'].get('ipAddress')
 | 
			
		||||
 | 
			
		||||
        self.interface_ip = self.public_ipv4 or self.private_ipv4
 | 
			
		||||
        self.region = vm['location']
 | 
			
		||||
        self.az = ''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureDeleteStateMachine(statemachine.StateMachine):
 | 
			
		||||
    VM_DELETING = 'deleting vm'
 | 
			
		||||
    NIC_DELETING = 'deleting nic'
 | 
			
		||||
    PIP_DELETING = 'deleting pip'
 | 
			
		||||
    DISK_DELETING = 'deleting disk'
 | 
			
		||||
    COMPLETE = 'complete'
 | 
			
		||||
 | 
			
		||||
    def __init__(self, adapter, external_id):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.adapter = adapter
 | 
			
		||||
        self.external_id = external_id
 | 
			
		||||
 | 
			
		||||
    def advance(self):
 | 
			
		||||
        if self.state == self.START:
 | 
			
		||||
            self.vm = self.adapter._deleteVirtualMachine(
 | 
			
		||||
                self.external_id)
 | 
			
		||||
            self.state = self.VM_DELETING
 | 
			
		||||
 | 
			
		||||
        if self.state == self.VM_DELETING:
 | 
			
		||||
            self.vm = self.adapter._refresh_delete(self.vm)
 | 
			
		||||
            if self.vm is None:
 | 
			
		||||
                self.nic = self.adapter._deleteNetworkInterface(
 | 
			
		||||
                    self.external_id + '-nic')
 | 
			
		||||
                self.state = self.NIC_DELETING
 | 
			
		||||
 | 
			
		||||
        if self.state == self.NIC_DELETING:
 | 
			
		||||
            self.nic = self.adapter._refresh_delete(self.nic)
 | 
			
		||||
            if self.nic is None:
 | 
			
		||||
                self.pip = self.adapter._deletePublicIPAddress(
 | 
			
		||||
                    self.external_id + '-nic-pip')
 | 
			
		||||
                self.state = self.PIP_DELETING
 | 
			
		||||
 | 
			
		||||
        if self.state == self.PIP_DELETING:
 | 
			
		||||
            self.pip = self.adapter._refresh_delete(self.pip)
 | 
			
		||||
            if self.pip is None:
 | 
			
		||||
                self.disks = []
 | 
			
		||||
                for disk in self.adapter._listDisks():
 | 
			
		||||
                    if disk['tags'] is not None and \
 | 
			
		||||
                       disk['tags'].get('nodepool_id') == self.external_id:
 | 
			
		||||
                        disk = self.adapter._deleteDisk(disk['name'])
 | 
			
		||||
                        self.disks.append(disk)
 | 
			
		||||
                self.state = self.DISK_DELETING
 | 
			
		||||
 | 
			
		||||
        if self.state == self.DISK_DELETING:
 | 
			
		||||
            all_deleted = True
 | 
			
		||||
            for disk in self.disks:
 | 
			
		||||
                disk = self.adapter._refresh_delete(disk)
 | 
			
		||||
                if disk:
 | 
			
		||||
                    all_deleted = False
 | 
			
		||||
            if all_deleted:
 | 
			
		||||
                self.state = self.COMPLETE
 | 
			
		||||
                self.complete = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureCreateStateMachine(statemachine.StateMachine):
 | 
			
		||||
    PIP_CREATING = 'creating pip'
 | 
			
		||||
    NIC_CREATING = 'creating nic'
 | 
			
		||||
    VM_CREATING = 'creating vm'
 | 
			
		||||
    NIC_QUERY = 'querying nic'
 | 
			
		||||
    PIP_QUERY = 'querying pip'
 | 
			
		||||
    COMPLETE = 'complete'
 | 
			
		||||
 | 
			
		||||
    def __init__(self, adapter, hostname, label, metadata, retries):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.adapter = adapter
 | 
			
		||||
        self.retries = retries
 | 
			
		||||
        self.metadata = metadata
 | 
			
		||||
        self.tags = label.tags or {}
 | 
			
		||||
        self.tags.update(metadata)
 | 
			
		||||
        self.hostname = hostname
 | 
			
		||||
        self.label = label
 | 
			
		||||
        self.pip = None
 | 
			
		||||
        self.nic = None
 | 
			
		||||
        self.vm = None
 | 
			
		||||
 | 
			
		||||
    def advance(self):
 | 
			
		||||
        if self.state == self.START:
 | 
			
		||||
            self.pip = self.adapter._createPublicIPAddress(
 | 
			
		||||
                self.tags, self.hostname)
 | 
			
		||||
            self.state = self.PIP_CREATING
 | 
			
		||||
            self.external_id = self.hostname
 | 
			
		||||
 | 
			
		||||
        if self.state == self.PIP_CREATING:
 | 
			
		||||
            self.pip = self.adapter._refresh(self.pip)
 | 
			
		||||
            if self.adapter._succeeded(self.pip):
 | 
			
		||||
                self.nic = self.adapter._createNetworkInterface(
 | 
			
		||||
                    self.tags, self.hostname, self.pip)
 | 
			
		||||
                self.state = self.NIC_CREATING
 | 
			
		||||
            else:
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
        if self.state == self.NIC_CREATING:
 | 
			
		||||
            self.nic = self.adapter._refresh(self.nic)
 | 
			
		||||
            if self.adapter._succeeded(self.nic):
 | 
			
		||||
                self.vm = self.adapter._createVirtualMachine(
 | 
			
		||||
                    self.label, self.tags, self.hostname, self.nic)
 | 
			
		||||
                self.state = self.VM_CREATING
 | 
			
		||||
            else:
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
        if self.state == self.VM_CREATING:
 | 
			
		||||
            self.vm = self.adapter._refresh(self.vm)
 | 
			
		||||
            # if 404:
 | 
			
		||||
            #   increment retries
 | 
			
		||||
            #   state = self.NIC_CREATING
 | 
			
		||||
            # if error:
 | 
			
		||||
            #   if retries too big: raise error
 | 
			
		||||
            #   delete vm
 | 
			
		||||
            if self.adapter._succeeded(self.vm):
 | 
			
		||||
                self.state = self.NIC_QUERY
 | 
			
		||||
            else:
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
        if self.state == self.NIC_QUERY:
 | 
			
		||||
            self.nic = self.adapter._refresh(self.nic, force=True)
 | 
			
		||||
            for ip_config_data in self.nic['properties']['ipConfigurations']:
 | 
			
		||||
                ip_config_prop = ip_config_data['properties']
 | 
			
		||||
                if ip_config_prop['privateIPAddressVersion'] == 'IPv4':
 | 
			
		||||
                    if 'privateIPAddress' in ip_config_prop:
 | 
			
		||||
                        self.state = self.PIP_QUERY
 | 
			
		||||
 | 
			
		||||
        if self.state == self.PIP_QUERY:
 | 
			
		||||
            self.pip = self.adapter._refresh(self.pip, force=True)
 | 
			
		||||
            if 'ipAddress' in self.pip['properties']:
 | 
			
		||||
                self.state = self.COMPLETE
 | 
			
		||||
 | 
			
		||||
        if self.state == self.COMPLETE:
 | 
			
		||||
            self.complete = True
 | 
			
		||||
            return AzureInstance(self.vm, self.nic, self.pip)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureAdapter(statemachine.Adapter):
 | 
			
		||||
    log = logging.getLogger("nodepool.driver.azure.AzureAdapter")
 | 
			
		||||
 | 
			
		||||
    def __init__(self, provider_config):
 | 
			
		||||
        self.provider = provider_config
 | 
			
		||||
        self.resource_group = self.provider.resource_group
 | 
			
		||||
        self.resource_group_location = self.provider.resource_group_location
 | 
			
		||||
        self.rate_limiter = RateLimiter(self.provider.name,
 | 
			
		||||
                                        self.provider.rate_limit)
 | 
			
		||||
        with open(self.provider.auth_path) as f:
 | 
			
		||||
            self.azul = azul.AzureCloud(json.load(f))
 | 
			
		||||
 | 
			
		||||
    def getCreateStateMachine(self, hostname, label, metadata, retries):
 | 
			
		||||
        return AzureCreateStateMachine(
 | 
			
		||||
            self, hostname, label, metadata, retries)
 | 
			
		||||
 | 
			
		||||
    def getDeleteStateMachine(self, external_id):
 | 
			
		||||
        return AzureDeleteStateMachine(self, external_id)
 | 
			
		||||
 | 
			
		||||
    def cleanupLeakedResources(self, known_nodes, metadata):
 | 
			
		||||
        for vm in self._listVirtualMachines():
 | 
			
		||||
            node_id = self._metadataMatches(vm, metadata)
 | 
			
		||||
            if (node_id and node_id not in known_nodes):
 | 
			
		||||
                self.log.info(f"Deleting leaked vm: {vm['name']}")
 | 
			
		||||
                self.azul.virtual_machines.delete(
 | 
			
		||||
                    self.resource_group, vm['name'])
 | 
			
		||||
        for nic in self._listNetworkInterfaces():
 | 
			
		||||
            node_id = self._metadataMatches(nic, metadata)
 | 
			
		||||
            if (node_id and node_id not in known_nodes):
 | 
			
		||||
                self.log.info(f"Deleting leaked nic: {nic['name']}")
 | 
			
		||||
                self.azul.network_interfaces.delete(
 | 
			
		||||
                    self.resource_group, nic['name'])
 | 
			
		||||
        for pip in self._listPublicIPAddresses():
 | 
			
		||||
            node_id = self._metadataMatches(pip, metadata)
 | 
			
		||||
            if (node_id and node_id not in known_nodes):
 | 
			
		||||
                self.log.info(f"Deleting leaked pip: {pip['name']}")
 | 
			
		||||
                self.azul.public_ip_addresses.delete(
 | 
			
		||||
                    self.resource_group, pip['name'])
 | 
			
		||||
        for disk in self._listDisks():
 | 
			
		||||
            node_id = self._metadataMatches(disk, metadata)
 | 
			
		||||
            if (node_id and node_id not in known_nodes):
 | 
			
		||||
                self.log.info(f"Deleting leaked disk: {disk['name']}")
 | 
			
		||||
                self.azul.disks.delete(self.resource_group, disk['name'])
 | 
			
		||||
 | 
			
		||||
    def listInstances(self):
 | 
			
		||||
        for vm in self._listVirtualMachines():
 | 
			
		||||
            yield AzureInstance(vm)
 | 
			
		||||
 | 
			
		||||
    def getQuotaLimits(self):
 | 
			
		||||
        return QuotaInformation(default=math.inf)
 | 
			
		||||
 | 
			
		||||
    def getQuotaForLabel(self, label):
 | 
			
		||||
        return QuotaInformation(instances=1)
 | 
			
		||||
 | 
			
		||||
    # Local implementation below
 | 
			
		||||
 | 
			
		||||
    def _metadataMatches(self, obj, metadata):
 | 
			
		||||
        if 'tags' not in obj:
 | 
			
		||||
            return None
 | 
			
		||||
        for k, v in metadata.items():
 | 
			
		||||
            if obj['tags'].get(k) != v:
 | 
			
		||||
                return None
 | 
			
		||||
        return obj['tags']['nodepool_node_id']
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _succeeded(obj):
 | 
			
		||||
        return obj['properties']['provisioningState'] == 'Succeeded'
 | 
			
		||||
 | 
			
		||||
    def _refresh(self, obj, force=False):
 | 
			
		||||
        if self._succeeded(obj) and not force:
 | 
			
		||||
            return obj
 | 
			
		||||
 | 
			
		||||
        if obj['type'] == 'Microsoft.Network/publicIPAddresses':
 | 
			
		||||
            l = self._listPublicIPAddresses()
 | 
			
		||||
        if obj['type'] == 'Microsoft.Network/networkInterfaces':
 | 
			
		||||
            l = self._listNetworkInterfaces()
 | 
			
		||||
        if obj['type'] == 'Microsoft.Compute/virtualMachines':
 | 
			
		||||
            l = self._listVirtualMachines()
 | 
			
		||||
 | 
			
		||||
        for new_obj in l:
 | 
			
		||||
            if new_obj['id'] == obj['id']:
 | 
			
		||||
                return new_obj
 | 
			
		||||
        return obj
 | 
			
		||||
 | 
			
		||||
    def _refresh_delete(self, obj):
 | 
			
		||||
        if obj is None:
 | 
			
		||||
            return obj
 | 
			
		||||
 | 
			
		||||
        if obj['type'] == 'Microsoft.Network/publicIPAddresses':
 | 
			
		||||
            l = self._listPublicIPAddresses()
 | 
			
		||||
        if obj['type'] == 'Microsoft.Network/networkInterfaces':
 | 
			
		||||
            l = self._listNetworkInterfaces()
 | 
			
		||||
        if obj['type'] == 'Microsoft.Compute/virtualMachines':
 | 
			
		||||
            l = self._listVirtualMachines()
 | 
			
		||||
 | 
			
		||||
        for new_obj in l:
 | 
			
		||||
            if new_obj['id'] == obj['id']:
 | 
			
		||||
                return new_obj
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    @cachetools.func.ttl_cache(maxsize=1, ttl=10)
 | 
			
		||||
    def _listPublicIPAddresses(self):
 | 
			
		||||
        return self.azul.public_ip_addresses.list(self.resource_group)
 | 
			
		||||
 | 
			
		||||
    def _createPublicIPAddress(self, tags, hostname):
 | 
			
		||||
        v4_params_create = {
 | 
			
		||||
            'location': self.provider.location,
 | 
			
		||||
            'tags': tags,
 | 
			
		||||
            'properties': {
 | 
			
		||||
                'publicIpAllocationMethod': 'dynamic',
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
        return self.azul.public_ip_addresses.create(
 | 
			
		||||
            self.resource_group,
 | 
			
		||||
            "%s-nic-pip" % hostname,
 | 
			
		||||
            v4_params_create,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _deletePublicIPAddress(self, name):
 | 
			
		||||
        for pip in self._listPublicIPAddresses():
 | 
			
		||||
            if pip['name'] == name:
 | 
			
		||||
                break
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
        self.azul.public_ip_addresses.delete(self.resource_group, name)
 | 
			
		||||
        return pip
 | 
			
		||||
 | 
			
		||||
    @cachetools.func.ttl_cache(maxsize=1, ttl=10)
 | 
			
		||||
    def _listNetworkInterfaces(self):
 | 
			
		||||
        return self.azul.network_interfaces.list(self.resource_group)
 | 
			
		||||
 | 
			
		||||
    def _createNetworkInterface(self, tags, hostname, pip):
 | 
			
		||||
        nic_data = {
 | 
			
		||||
            'location': self.provider.location,
 | 
			
		||||
            'tags': tags,
 | 
			
		||||
            'properties': {
 | 
			
		||||
                'ipConfigurations': [{
 | 
			
		||||
                    'name': "nodepool-v4-ip-config",
 | 
			
		||||
                    'properties': {
 | 
			
		||||
                        'privateIpAddressVersion': 'IPv4',
 | 
			
		||||
                        'subnet': {
 | 
			
		||||
                            'id': self.provider.subnet_id
 | 
			
		||||
                        },
 | 
			
		||||
                        'publicIpAddress': {
 | 
			
		||||
                            'id': pip['id']
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }]
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if self.provider.ipv6:
 | 
			
		||||
            nic_data['properties']['ipConfigurations'].append({
 | 
			
		||||
                'name': "nodepool-v6-ip-config",
 | 
			
		||||
                'properties': {
 | 
			
		||||
                    'privateIpAddressVersion': 'IPv6',
 | 
			
		||||
                    'subnet': {
 | 
			
		||||
                        'id': self.provider.subnet_id
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
        return self.azul.network_interfaces.create(
 | 
			
		||||
            self.resource_group,
 | 
			
		||||
            "%s-nic" % hostname,
 | 
			
		||||
            nic_data
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _deleteNetworkInterface(self, name):
 | 
			
		||||
        for nic in self._listNetworkInterfaces():
 | 
			
		||||
            if nic['name'] == name:
 | 
			
		||||
                break
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
        self.azul.network_interfaces.delete(self.resource_group, name)
 | 
			
		||||
        return nic
 | 
			
		||||
 | 
			
		||||
    @cachetools.func.ttl_cache(maxsize=1, ttl=10)
 | 
			
		||||
    def _listVirtualMachines(self):
 | 
			
		||||
        return self.azul.virtual_machines.list(self.resource_group)
 | 
			
		||||
 | 
			
		||||
    def _createVirtualMachine(self, label, tags, hostname, nic):
 | 
			
		||||
        return self.azul.virtual_machines.create(
 | 
			
		||||
            self.resource_group, hostname, {
 | 
			
		||||
                'location': self.provider.location,
 | 
			
		||||
                'tags': tags,
 | 
			
		||||
                'properties': {
 | 
			
		||||
                    'osProfile': {
 | 
			
		||||
                        'computerName': hostname,
 | 
			
		||||
                        'adminUsername': label.cloud_image.username,
 | 
			
		||||
                        'linuxConfiguration': {
 | 
			
		||||
                            'ssh': {
 | 
			
		||||
                                'publicKeys': [{
 | 
			
		||||
                                    'path': "/home/%s/.ssh/authorized_keys" % (
 | 
			
		||||
                                        label.cloud_image.username),
 | 
			
		||||
                                    'keyData': label.cloud_image.key,
 | 
			
		||||
                                }]
 | 
			
		||||
                            },
 | 
			
		||||
                            "disablePasswordAuthentication": True,
 | 
			
		||||
                        }
 | 
			
		||||
                    },
 | 
			
		||||
                    'hardwareProfile': {
 | 
			
		||||
                        'vmSize': label.hardware_profile["vm-size"]
 | 
			
		||||
                    },
 | 
			
		||||
                    'storageProfile': {
 | 
			
		||||
                        'imageReference': label.cloud_image.image_reference
 | 
			
		||||
                    },
 | 
			
		||||
                    'networkProfile': {
 | 
			
		||||
                        'networkInterfaces': [{
 | 
			
		||||
                            'id': nic['id'],
 | 
			
		||||
                            'properties': {
 | 
			
		||||
                                'primary': True,
 | 
			
		||||
                            }
 | 
			
		||||
                        }]
 | 
			
		||||
                    },
 | 
			
		||||
                },
 | 
			
		||||
            })
 | 
			
		||||
 | 
			
		||||
    def _deleteVirtualMachine(self, name):
 | 
			
		||||
        for vm in self._listVirtualMachines():
 | 
			
		||||
            if vm['name'] == name:
 | 
			
		||||
                break
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
        self.azul.virtual_machines.delete(self.resource_group, name)
 | 
			
		||||
        return vm
 | 
			
		||||
 | 
			
		||||
    @cachetools.func.ttl_cache(maxsize=1, ttl=10)
 | 
			
		||||
    def _listDisks(self):
 | 
			
		||||
        return self.azul.disks.list(self.resource_group)
 | 
			
		||||
 | 
			
		||||
    def _deleteDisk(self, name):
 | 
			
		||||
        for disk in self._listNetworkInterfaces():
 | 
			
		||||
            if disk['name'] == name:
 | 
			
		||||
                break
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
        self.azul.disks.delete(self.resource_group, name)
 | 
			
		||||
        return disk
 | 
			
		||||
							
								
								
									
										269
									
								
								nodepool/driver/azurestate/azul.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										269
									
								
								nodepool/driver/azurestate/azul.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,269 @@
 | 
			
		||||
# Copyright 2021 Acme Gating, LLC
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
			
		||||
# not use this file except in compliance with the License. You may obtain
 | 
			
		||||
# a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
# License for the specific language governing permissions and limitations
 | 
			
		||||
# under the License.
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import logging
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureAuth(requests.auth.AuthBase):
 | 
			
		||||
    AUTH_URL = "https://login.microsoftonline.com/{tenantId}/oauth2/token"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, credential):
 | 
			
		||||
        self.log = logging.getLogger("azul.auth")
 | 
			
		||||
        self.credential = credential
 | 
			
		||||
        self.token = None
 | 
			
		||||
        self.expiration = time.time()
 | 
			
		||||
 | 
			
		||||
    def refresh(self):
 | 
			
		||||
        if self.expiration - time.time() < 60:
 | 
			
		||||
            self.log.debug('Refreshing authentication token')
 | 
			
		||||
            url = self.AUTH_URL.format(**self.credential)
 | 
			
		||||
            data = {
 | 
			
		||||
                'grant_type': 'client_credentials',
 | 
			
		||||
                'client_id': self.credential['clientId'],
 | 
			
		||||
                'client_secret': self.credential['clientSecret'],
 | 
			
		||||
                'resource': 'https://management.azure.com/',
 | 
			
		||||
            }
 | 
			
		||||
            r = requests.post(url, data)
 | 
			
		||||
            ret = r.json()
 | 
			
		||||
            self.token = ret['access_token']
 | 
			
		||||
            self.expiration = float(ret['expires_on'])
 | 
			
		||||
 | 
			
		||||
    def __call__(self, r):
 | 
			
		||||
        self.refresh()
 | 
			
		||||
        r.headers["authorization"] = "Bearer " + self.token
 | 
			
		||||
        return r
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureError(Exception):
 | 
			
		||||
    def __init__(self, status_code, message):
 | 
			
		||||
        super().__init__(message)
 | 
			
		||||
        self.status_code = status_code
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureNotFoundError(AzureError):
 | 
			
		||||
    def __init__(self, status_code, message):
 | 
			
		||||
        super().__init__(status_code, message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureResourceGroupsCRUD:
 | 
			
		||||
    def __init__(self, cloud, version):
 | 
			
		||||
        self.cloud = cloud
 | 
			
		||||
        self.version = version
 | 
			
		||||
 | 
			
		||||
    def url(self, url, **args):
 | 
			
		||||
        base_url = (
 | 
			
		||||
            'https://management.azure.com/subscriptions/{subscriptionId}'
 | 
			
		||||
            '/resourcegroups/')
 | 
			
		||||
        url = base_url + url + '?api-version={apiVersion}'
 | 
			
		||||
        args = args.copy()
 | 
			
		||||
        args.update(self.cloud.credential)
 | 
			
		||||
        args['apiVersion'] = self.version
 | 
			
		||||
        return url.format(**args)
 | 
			
		||||
 | 
			
		||||
    def list(self):
 | 
			
		||||
        url = self.url('')
 | 
			
		||||
        return self.cloud.paginate(self.cloud.get(url))
 | 
			
		||||
 | 
			
		||||
    def get(self, name):
 | 
			
		||||
        url = self.url(name)
 | 
			
		||||
        return self.cloud.get(url)
 | 
			
		||||
 | 
			
		||||
    def create(self, name, params):
 | 
			
		||||
        url = self.url(name)
 | 
			
		||||
        return self.cloud.put(url, params)
 | 
			
		||||
 | 
			
		||||
    def delete(self, name):
 | 
			
		||||
        url = self.url(name)
 | 
			
		||||
        return self.cloud.delete(url)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureCRUD:
 | 
			
		||||
    def __init__(self, cloud, resource, version):
 | 
			
		||||
        self.cloud = cloud
 | 
			
		||||
        self.resource = resource
 | 
			
		||||
        self.version = version
 | 
			
		||||
 | 
			
		||||
    def url(self, url, **args):
 | 
			
		||||
        base_url = (
 | 
			
		||||
            'https://management.azure.com/subscriptions/{subscriptionId}'
 | 
			
		||||
            '/resourceGroups/{resourceGroupName}/providers/')
 | 
			
		||||
        url = base_url + url + '?api-version={apiVersion}'
 | 
			
		||||
        args = args.copy()
 | 
			
		||||
        args.update(self.cloud.credential)
 | 
			
		||||
        args['apiVersion'] = self.version
 | 
			
		||||
        return url.format(**args)
 | 
			
		||||
 | 
			
		||||
    def id_url(self, url, **args):
 | 
			
		||||
        base_url = 'https://management.azure.com'
 | 
			
		||||
        url = base_url + url + '?api-version={apiVersion}'
 | 
			
		||||
        args = args.copy()
 | 
			
		||||
        args['apiVersion'] = self.version
 | 
			
		||||
        return url.format(**args)
 | 
			
		||||
 | 
			
		||||
    def list(self, resource_group_name):
 | 
			
		||||
        url = self.url(
 | 
			
		||||
            self.resource,
 | 
			
		||||
            resourceGroupName=resource_group_name,
 | 
			
		||||
        )
 | 
			
		||||
        return self.cloud.paginate(self.cloud.get(url))
 | 
			
		||||
 | 
			
		||||
    def get_by_id(self, resource_id):
 | 
			
		||||
        url = self.id_url(resource_id)
 | 
			
		||||
        return self.cloud.get(url)
 | 
			
		||||
 | 
			
		||||
    def get(self, resource_group_name, name):
 | 
			
		||||
        url = self.url(
 | 
			
		||||
            '{_resource}/{_resourceName}',
 | 
			
		||||
            _resource=self.resource,
 | 
			
		||||
            _resourceName=name,
 | 
			
		||||
            resourceGroupName=resource_group_name,
 | 
			
		||||
        )
 | 
			
		||||
        return self.cloud.get(url)
 | 
			
		||||
 | 
			
		||||
    def create(self, resource_group_name, name, params):
 | 
			
		||||
        url = self.url(
 | 
			
		||||
            '{_resource}/{_resourceName}',
 | 
			
		||||
            _resource=self.resource,
 | 
			
		||||
            _resourceName=name,
 | 
			
		||||
            resourceGroupName=resource_group_name,
 | 
			
		||||
        )
 | 
			
		||||
        return self.cloud.put(url, params)
 | 
			
		||||
 | 
			
		||||
    def delete(self, resource_group_name, name):
 | 
			
		||||
        url = self.url(
 | 
			
		||||
            '{_resource}/{_resourceName}',
 | 
			
		||||
            _resource=self.resource,
 | 
			
		||||
            _resourceName=name,
 | 
			
		||||
            resourceGroupName=resource_group_name,
 | 
			
		||||
        )
 | 
			
		||||
        return self.cloud.delete(url)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureDictResponse(dict):
 | 
			
		||||
    def __init__(self, response, *args):
 | 
			
		||||
        super().__init__(*args)
 | 
			
		||||
        self.response = response
 | 
			
		||||
        self.last_retry = time.time()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureListResponse(list):
 | 
			
		||||
    def __init__(self, response, *args):
 | 
			
		||||
        super().__init__(*args)
 | 
			
		||||
        self.response = response
 | 
			
		||||
        self.last_retry = time.time()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureCloud:
 | 
			
		||||
    TIMEOUT = 60
 | 
			
		||||
 | 
			
		||||
    def __init__(self, credential):
 | 
			
		||||
        self.credential = credential
 | 
			
		||||
        self.session = requests.Session()
 | 
			
		||||
        self.log = logging.getLogger("azul")
 | 
			
		||||
        self.auth = AzureAuth(credential)
 | 
			
		||||
        self.network_interfaces = AzureCRUD(
 | 
			
		||||
            self,
 | 
			
		||||
            'Microsoft.Network/networkInterfaces',
 | 
			
		||||
            '2020-07-01')
 | 
			
		||||
        self.public_ip_addresses = AzureCRUD(
 | 
			
		||||
            self,
 | 
			
		||||
            'Microsoft.Network/publicIPAddresses',
 | 
			
		||||
            '2020-07-01')
 | 
			
		||||
        self.virtual_machines = AzureCRUD(
 | 
			
		||||
            self,
 | 
			
		||||
            'Microsoft.Compute/virtualMachines',
 | 
			
		||||
            '2020-12-01')
 | 
			
		||||
        self.disks = AzureCRUD(
 | 
			
		||||
            self,
 | 
			
		||||
            'Microsoft.Compute/disks',
 | 
			
		||||
            '2020-06-30')
 | 
			
		||||
        self.resource_groups = AzureResourceGroupsCRUD(
 | 
			
		||||
            self,
 | 
			
		||||
            '2020-06-01')
 | 
			
		||||
 | 
			
		||||
    def get(self, url, codes=[200]):
 | 
			
		||||
        return self.request('GET', url, None, codes)
 | 
			
		||||
 | 
			
		||||
    def put(self, url, data, codes=[200, 201]):
 | 
			
		||||
        return self.request('PUT', url, data, codes)
 | 
			
		||||
 | 
			
		||||
    def delete(self, url, codes=[200, 201, 202, 204]):
 | 
			
		||||
        return self.request('DELETE', url, None, codes)
 | 
			
		||||
 | 
			
		||||
    def request(self, method, url, data, codes):
 | 
			
		||||
        self.log.debug('%s: %s %s' % (method, url, data))
 | 
			
		||||
        response = self.session.request(
 | 
			
		||||
            method, url, json=data,
 | 
			
		||||
            auth=self.auth, timeout=self.TIMEOUT,
 | 
			
		||||
            headers={'Accept': 'application/json',
 | 
			
		||||
                     'Accept-Encoding': 'gzip'})
 | 
			
		||||
 | 
			
		||||
        self.log.debug("Received headers: %s", response.headers)
 | 
			
		||||
        if response.status_code in codes:
 | 
			
		||||
            if len(response.text):
 | 
			
		||||
                self.log.debug("Received: %s", response.text)
 | 
			
		||||
                ret_data = response.json()
 | 
			
		||||
                if isinstance(ret_data, list):
 | 
			
		||||
                    return AzureListResponse(response, ret_data)
 | 
			
		||||
                else:
 | 
			
		||||
                    return AzureDictResponse(response, ret_data)
 | 
			
		||||
            self.log.debug("Empty response")
 | 
			
		||||
            return AzureDictResponse(response, {})
 | 
			
		||||
        err = response.json()
 | 
			
		||||
        self.log.error(response.text)
 | 
			
		||||
        if response.status_code == 404:
 | 
			
		||||
            raise AzureNotFoundError(
 | 
			
		||||
                response.status_code, err['error']['message'])
 | 
			
		||||
        else:
 | 
			
		||||
            raise AzureError(response.status_code, err['error']['message'])
 | 
			
		||||
 | 
			
		||||
    def paginate(self, data):
 | 
			
		||||
        ret = data['value']
 | 
			
		||||
        while 'nextLink' in data:
 | 
			
		||||
            data = self.get(data['nextLink'])
 | 
			
		||||
            ret += data['value']
 | 
			
		||||
        return ret
 | 
			
		||||
 | 
			
		||||
    def check_async_operation(self, response):
 | 
			
		||||
        resp = response.response
 | 
			
		||||
        location = resp.headers.get(
 | 
			
		||||
            'Azure-AsyncOperation',
 | 
			
		||||
            resp.headers.get('Location', None))
 | 
			
		||||
        if not location:
 | 
			
		||||
            self.log.debug("No async operation found")
 | 
			
		||||
            return None
 | 
			
		||||
        remain = (response.last_retry +
 | 
			
		||||
                  float(resp.headers.get('Retry-After', 2))) - time.time()
 | 
			
		||||
        self.log.debug("remain time %s", remain)
 | 
			
		||||
        if remain > 0:
 | 
			
		||||
            time.sleep(remain)
 | 
			
		||||
        response.last_retry = time.time()
 | 
			
		||||
        return self.get(location)
 | 
			
		||||
 | 
			
		||||
    def wait_for_async_operation(self, response, timeout=600):
 | 
			
		||||
        start = time.time()
 | 
			
		||||
        while True:
 | 
			
		||||
            if time.time() - start > timeout:
 | 
			
		||||
                raise Exception("Timeout waiting for async operation")
 | 
			
		||||
            ret = self.check_async_operation(response)
 | 
			
		||||
            if ret is None:
 | 
			
		||||
                return
 | 
			
		||||
            if ret['status'] == 'InProgress':
 | 
			
		||||
                continue
 | 
			
		||||
            if ret['status'] == 'Succeeded':
 | 
			
		||||
                return
 | 
			
		||||
            raise Exception("Unhandled async operation result: %s",
 | 
			
		||||
                            ret['status'])
 | 
			
		||||
							
								
								
									
										225
									
								
								nodepool/driver/azurestate/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								nodepool/driver/azurestate/config.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,225 @@
 | 
			
		||||
# Copyright 2018 Red Hat
 | 
			
		||||
# Copyright 2021 Acme Gating, LLC
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 | 
			
		||||
# implied.
 | 
			
		||||
#
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import voluptuous as v
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
from nodepool.driver import ConfigPool
 | 
			
		||||
from nodepool.driver import ConfigValue
 | 
			
		||||
from nodepool.driver import ProviderConfig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProviderCloudImage(ConfigValue):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.name = None
 | 
			
		||||
        self.image_id = None
 | 
			
		||||
        self.username = None
 | 
			
		||||
        self.key = None
 | 
			
		||||
        self.python_path = None
 | 
			
		||||
        self.connection_type = None
 | 
			
		||||
        self.connection_port = None
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        if isinstance(other, ProviderCloudImage):
 | 
			
		||||
            return (self.name == other.name
 | 
			
		||||
                    and self.image_id == other.image_id
 | 
			
		||||
                    and self.username == other.username
 | 
			
		||||
                    and self.key == other.key
 | 
			
		||||
                    and self.python_path == other.python_path
 | 
			
		||||
                    and self.connection_type == other.connection_type
 | 
			
		||||
                    and self.connection_port == other.connection_port)
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "<ProviderCloudImage %s>" % self.name
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def external_name(self):
 | 
			
		||||
        '''Human readable version of external.'''
 | 
			
		||||
        return self.image_id or self.name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureLabel(ConfigValue):
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        if (  # other.username != self.username or
 | 
			
		||||
              # other.imageReference != self.imageReference or
 | 
			
		||||
            other.hardware_profile != self.hardware_profile):
 | 
			
		||||
            return False
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzurePool(ConfigPool):
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        if other.labels != self.labels:
 | 
			
		||||
            return False
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "<AzurePool %s>" % self.name
 | 
			
		||||
 | 
			
		||||
    def load(self, pool_config):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AzureProviderConfig(ProviderConfig):
 | 
			
		||||
    def __init__(self, driver, provider):
 | 
			
		||||
        self._pools = {}
 | 
			
		||||
        self.driver_object = driver
 | 
			
		||||
        self.rate_limit = None
 | 
			
		||||
        self.launch_retries = None
 | 
			
		||||
        super().__init__(provider)
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        if (other.location != self.location or
 | 
			
		||||
            other.pools != self.pools):
 | 
			
		||||
            return False
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def pools(self):
 | 
			
		||||
        return self._pools
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def manage_images(self):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def reset():
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def load(self, config):
 | 
			
		||||
        self.rate_limit = self.provider.get('rate-limit', 1)
 | 
			
		||||
        self.launch_retries = self.provider.get('launch-retries', 3)
 | 
			
		||||
        self.boot_timeout = self.provider.get('boot-timeout', 60)
 | 
			
		||||
 | 
			
		||||
        self.zuul_public_key = self.provider['zuul-public-key']
 | 
			
		||||
        self.location = self.provider['location']
 | 
			
		||||
        self.subnet_id = self.provider['subnet-id']
 | 
			
		||||
        self.ipv6 = self.provider.get('ipv6', False)
 | 
			
		||||
        self.resource_group = self.provider['resource-group']
 | 
			
		||||
        self.resource_group_location = self.provider['resource-group-location']
 | 
			
		||||
        self.auth_path = self.provider.get(
 | 
			
		||||
            'auth-path', os.getenv('AZURE_AUTH_LOCATION', None))
 | 
			
		||||
 | 
			
		||||
        default_port_mapping = {
 | 
			
		||||
            'ssh': 22,
 | 
			
		||||
            'winrm': 5986,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.cloud_images = {}
 | 
			
		||||
        for image in self.provider['cloud-images']:
 | 
			
		||||
            i = ProviderCloudImage()
 | 
			
		||||
            i.name = image['name']
 | 
			
		||||
            i.username = image['username']
 | 
			
		||||
            # i.key = image['key']
 | 
			
		||||
            i.key = self.zuul_public_key
 | 
			
		||||
            i.image_reference = image['image-reference']
 | 
			
		||||
            i.connection_type = image.get('connection-type', 'ssh')
 | 
			
		||||
            i.connection_port = image.get(
 | 
			
		||||
                'connection-port',
 | 
			
		||||
                default_port_mapping.get(i.connection_type, 22))
 | 
			
		||||
            self.cloud_images[i.name] = i
 | 
			
		||||
 | 
			
		||||
        for pool in self.provider.get('pools', []):
 | 
			
		||||
            pp = AzurePool()
 | 
			
		||||
            pp.name = pool['name']
 | 
			
		||||
            pp.provider = self
 | 
			
		||||
            pp.max_servers = pool['max-servers']
 | 
			
		||||
            pp.use_internal_ip = bool(pool.get('use-internal-ip', False))
 | 
			
		||||
            pp.host_key_checking = bool(pool.get(
 | 
			
		||||
                'host-key-checking', True))
 | 
			
		||||
            self._pools[pp.name] = pp
 | 
			
		||||
            pp.labels = {}
 | 
			
		||||
 | 
			
		||||
            for label in pool.get('labels', []):
 | 
			
		||||
                pl = AzureLabel()
 | 
			
		||||
                pl.name = label['name']
 | 
			
		||||
                pl.pool = pp
 | 
			
		||||
                pp.labels[pl.name] = pl
 | 
			
		||||
 | 
			
		||||
                cloud_image_name = label['cloud-image']
 | 
			
		||||
                if cloud_image_name:
 | 
			
		||||
                    cloud_image = self.cloud_images.get(
 | 
			
		||||
                        cloud_image_name, None)
 | 
			
		||||
                    if not cloud_image:
 | 
			
		||||
                        raise ValueError(
 | 
			
		||||
                            "cloud-image %s does not exist in provider %s"
 | 
			
		||||
                            " but is referenced in label %s" %
 | 
			
		||||
                            (cloud_image_name, self.name, pl.name))
 | 
			
		||||
                    # pl.imageReference = cloud_image['image-reference']
 | 
			
		||||
                    pl.cloud_image = cloud_image
 | 
			
		||||
                    # pl.username = cloud_image.get('username', 'zuul')
 | 
			
		||||
                else:
 | 
			
		||||
                    # pl.imageReference = None
 | 
			
		||||
                    pl.cloud_image = None
 | 
			
		||||
                    # pl.username = 'zuul'
 | 
			
		||||
 | 
			
		||||
                pl.hardware_profile = label['hardware-profile']
 | 
			
		||||
 | 
			
		||||
                config.labels[label['name']].pools.append(pp)
 | 
			
		||||
                pl.tags = label.get('tags', {})
 | 
			
		||||
 | 
			
		||||
    def getSchema(self):
 | 
			
		||||
 | 
			
		||||
        azure_image_reference = {
 | 
			
		||||
            v.Required('sku'): str,
 | 
			
		||||
            v.Required('publisher'): str,
 | 
			
		||||
            v.Required('version'): str,
 | 
			
		||||
            v.Required('offer'): str,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        azure_hardware_profile = {
 | 
			
		||||
            v.Required('vm-size'): str,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        provider_cloud_images = {
 | 
			
		||||
            v.Required('name'): str,
 | 
			
		||||
            'username': str,
 | 
			
		||||
            v.Required('image-reference'): azure_image_reference,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        azure_label = {
 | 
			
		||||
            v.Required('name'): str,
 | 
			
		||||
            v.Required('hardware-profile'): azure_hardware_profile,
 | 
			
		||||
            v.Required('cloud-image'): str,
 | 
			
		||||
            v.Optional('tags'): dict,
 | 
			
		||||
        }
 | 
			
		||||
        pool = ConfigPool.getCommonSchemaDict()
 | 
			
		||||
        pool.update({
 | 
			
		||||
            v.Required('name'): str,
 | 
			
		||||
            v.Required('labels'): [azure_label],
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
        provider = ProviderConfig.getCommonSchemaDict()
 | 
			
		||||
        provider.update({
 | 
			
		||||
            v.Required('zuul-public-key'): str,
 | 
			
		||||
            v.Required('pools'): [pool],
 | 
			
		||||
            v.Required('location'): str,
 | 
			
		||||
            v.Required('resource-group'): str,
 | 
			
		||||
            v.Required('resource-group-location'): str,
 | 
			
		||||
            v.Required('subnet-id'): str,
 | 
			
		||||
            v.Required('cloud-images'): [provider_cloud_images],
 | 
			
		||||
            v.Optional('auth-path'): str,
 | 
			
		||||
        })
 | 
			
		||||
        return v.Schema(provider)
 | 
			
		||||
 | 
			
		||||
    def getSupportedLabels(self, pool_name=None):
 | 
			
		||||
        labels = set()
 | 
			
		||||
        for pool in self._pools.values():
 | 
			
		||||
            if not pool_name or (pool.name == pool_name):
 | 
			
		||||
                labels.update(pool.labels.keys())
 | 
			
		||||
        return labels
 | 
			
		||||
@@ -157,8 +157,9 @@ class StateMachineNodeLauncher(stats.StatsReporter):
 | 
			
		||||
                    raise Exception("Driver implementation error: state "
 | 
			
		||||
                                    "machine must produce external ID "
 | 
			
		||||
                                    "after first advancement")
 | 
			
		||||
                self.updateNodeFromInstance(instance)
 | 
			
		||||
            if state_machine.complete:
 | 
			
		||||
                node.external_id = state_machine.external_id
 | 
			
		||||
                self.zk.storeNode(node)
 | 
			
		||||
            if state_machine.complete and not self.keyscan_future:
 | 
			
		||||
                self.log.debug("Submitting keyscan request")
 | 
			
		||||
                self.updateNodeFromInstance(instance)
 | 
			
		||||
                future = self.manager.keyscan_worker.submit(
 | 
			
		||||
@@ -252,14 +253,16 @@ class StateMachineNodeDeleter:
 | 
			
		||||
            if node.external_id:
 | 
			
		||||
                state_machine.advance()
 | 
			
		||||
                self.log.debug(f"State machine for {node.id} at "
 | 
			
		||||
                               "{state_machine.state}")
 | 
			
		||||
                               f"{state_machine.state}")
 | 
			
		||||
            else:
 | 
			
		||||
                self.state_machine.complete = True
 | 
			
		||||
 | 
			
		||||
            if not self.state_machine.complete:
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
        except exceptions.NotFound:
 | 
			
		||||
            self.log.info(f"Instance {node.external_id} not found in "
 | 
			
		||||
                          "provider {node.provider}")
 | 
			
		||||
                          f"provider {node.provider}")
 | 
			
		||||
        except Exception:
 | 
			
		||||
            self.log.exception("Exception deleting instance "
 | 
			
		||||
                               f"{node.external_id} from {node.provider}:")
 | 
			
		||||
@@ -423,15 +426,21 @@ class StateMachineProvider(Provider, QuotaSupport):
 | 
			
		||||
        self.state_machine_thread.start()
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
        self.log.debug("Stopping")
 | 
			
		||||
        self.running = False
 | 
			
		||||
        if self.state_machine_thread:
 | 
			
		||||
            while self.launchers or self.deleters:
 | 
			
		||||
                time.sleep(1)
 | 
			
		||||
            self.running = False
 | 
			
		||||
        if self.keyscan_worker:
 | 
			
		||||
            self.keyscan_worker.shutdown()
 | 
			
		||||
        if self.state_machine_thread:
 | 
			
		||||
            self.running = False
 | 
			
		||||
        self.log.debug("Stopped")
 | 
			
		||||
 | 
			
		||||
    def join(self):
 | 
			
		||||
        self.log.debug("Joining")
 | 
			
		||||
        if self.state_machine_thread:
 | 
			
		||||
            self.state_machine_thread.join()
 | 
			
		||||
        self.log.debug("Joined")
 | 
			
		||||
 | 
			
		||||
    def _runStateMachines(self):
 | 
			
		||||
        while self.running:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user