Merge "Add azure state machine driver"

This commit is contained in:
Zuul 2021-06-11 14:20:03 +00:00 committed by Gerrit Code Review
commit fa75f3b897
5 changed files with 952 additions and 6 deletions

View File

@ -0,0 +1,25 @@
# Copyright 2021 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from nodepool.driver.statemachine import StateMachineDriver
from nodepool.driver.azurestate.config import AzureProviderConfig
from nodepool.driver.azurestate.adapter import AzureAdapter
class AzureDriver(StateMachineDriver):
def getProviderConfig(self, provider):
return AzureProviderConfig(self, provider)
def getAdapter(self, provider_config):
return AzureAdapter(provider_config)

View File

@ -0,0 +1,418 @@
# Copyright 2021 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import math
import logging
import json
import cachetools.func
from nodepool.driver.utils import QuotaInformation, RateLimiter
from nodepool.driver import statemachine
from . import azul
class AzureInstance(statemachine.Instance):
def __init__(self, vm, nic=None, pip=None):
self.external_id = vm['name']
self.metadata = vm['tags'] or {}
self.private_ipv4 = None
self.public_ipv4 = None
self.public_ipv6 = None
if nic:
for ip_config_data in nic['properties']['ipConfigurations']:
ip_config_prop = ip_config_data['properties']
if ip_config_prop['privateIPAddressVersion'] == 'IPv4':
self.private_ipv4 = ip_config_prop['privateIPAddress']
# public_ipv6
if pip:
self.public_ipv4 = pip['properties'].get('ipAddress')
self.interface_ip = self.public_ipv4 or self.private_ipv4
self.region = vm['location']
self.az = ''
class AzureDeleteStateMachine(statemachine.StateMachine):
VM_DELETING = 'deleting vm'
NIC_DELETING = 'deleting nic'
PIP_DELETING = 'deleting pip'
DISK_DELETING = 'deleting disk'
COMPLETE = 'complete'
def __init__(self, adapter, external_id):
super().__init__()
self.adapter = adapter
self.external_id = external_id
def advance(self):
if self.state == self.START:
self.vm = self.adapter._deleteVirtualMachine(
self.external_id)
self.state = self.VM_DELETING
if self.state == self.VM_DELETING:
self.vm = self.adapter._refresh_delete(self.vm)
if self.vm is None:
self.nic = self.adapter._deleteNetworkInterface(
self.external_id + '-nic')
self.state = self.NIC_DELETING
if self.state == self.NIC_DELETING:
self.nic = self.adapter._refresh_delete(self.nic)
if self.nic is None:
self.pip = self.adapter._deletePublicIPAddress(
self.external_id + '-nic-pip')
self.state = self.PIP_DELETING
if self.state == self.PIP_DELETING:
self.pip = self.adapter._refresh_delete(self.pip)
if self.pip is None:
self.disks = []
for disk in self.adapter._listDisks():
if disk['tags'] is not None and \
disk['tags'].get('nodepool_id') == self.external_id:
disk = self.adapter._deleteDisk(disk['name'])
self.disks.append(disk)
self.state = self.DISK_DELETING
if self.state == self.DISK_DELETING:
all_deleted = True
for disk in self.disks:
disk = self.adapter._refresh_delete(disk)
if disk:
all_deleted = False
if all_deleted:
self.state = self.COMPLETE
self.complete = True
class AzureCreateStateMachine(statemachine.StateMachine):
PIP_CREATING = 'creating pip'
NIC_CREATING = 'creating nic'
VM_CREATING = 'creating vm'
NIC_QUERY = 'querying nic'
PIP_QUERY = 'querying pip'
COMPLETE = 'complete'
def __init__(self, adapter, hostname, label, metadata, retries):
super().__init__()
self.adapter = adapter
self.retries = retries
self.metadata = metadata
self.tags = label.tags or {}
self.tags.update(metadata)
self.hostname = hostname
self.label = label
self.pip = None
self.nic = None
self.vm = None
def advance(self):
if self.state == self.START:
self.pip = self.adapter._createPublicIPAddress(
self.tags, self.hostname)
self.state = self.PIP_CREATING
self.external_id = self.hostname
if self.state == self.PIP_CREATING:
self.pip = self.adapter._refresh(self.pip)
if self.adapter._succeeded(self.pip):
self.nic = self.adapter._createNetworkInterface(
self.tags, self.hostname, self.pip)
self.state = self.NIC_CREATING
else:
return
if self.state == self.NIC_CREATING:
self.nic = self.adapter._refresh(self.nic)
if self.adapter._succeeded(self.nic):
self.vm = self.adapter._createVirtualMachine(
self.label, self.tags, self.hostname, self.nic)
self.state = self.VM_CREATING
else:
return
if self.state == self.VM_CREATING:
self.vm = self.adapter._refresh(self.vm)
# if 404:
# increment retries
# state = self.NIC_CREATING
# if error:
# if retries too big: raise error
# delete vm
if self.adapter._succeeded(self.vm):
self.state = self.NIC_QUERY
else:
return
if self.state == self.NIC_QUERY:
self.nic = self.adapter._refresh(self.nic, force=True)
for ip_config_data in self.nic['properties']['ipConfigurations']:
ip_config_prop = ip_config_data['properties']
if ip_config_prop['privateIPAddressVersion'] == 'IPv4':
if 'privateIPAddress' in ip_config_prop:
self.state = self.PIP_QUERY
if self.state == self.PIP_QUERY:
self.pip = self.adapter._refresh(self.pip, force=True)
if 'ipAddress' in self.pip['properties']:
self.state = self.COMPLETE
if self.state == self.COMPLETE:
self.complete = True
return AzureInstance(self.vm, self.nic, self.pip)
class AzureAdapter(statemachine.Adapter):
log = logging.getLogger("nodepool.driver.azure.AzureAdapter")
def __init__(self, provider_config):
self.provider = provider_config
self.resource_group = self.provider.resource_group
self.resource_group_location = self.provider.resource_group_location
self.rate_limiter = RateLimiter(self.provider.name,
self.provider.rate_limit)
with open(self.provider.auth_path) as f:
self.azul = azul.AzureCloud(json.load(f))
def getCreateStateMachine(self, hostname, label, metadata, retries):
return AzureCreateStateMachine(
self, hostname, label, metadata, retries)
def getDeleteStateMachine(self, external_id):
return AzureDeleteStateMachine(self, external_id)
def cleanupLeakedResources(self, known_nodes, metadata):
for vm in self._listVirtualMachines():
node_id = self._metadataMatches(vm, metadata)
if (node_id and node_id not in known_nodes):
self.log.info(f"Deleting leaked vm: {vm['name']}")
self.azul.virtual_machines.delete(
self.resource_group, vm['name'])
for nic in self._listNetworkInterfaces():
node_id = self._metadataMatches(nic, metadata)
if (node_id and node_id not in known_nodes):
self.log.info(f"Deleting leaked nic: {nic['name']}")
self.azul.network_interfaces.delete(
self.resource_group, nic['name'])
for pip in self._listPublicIPAddresses():
node_id = self._metadataMatches(pip, metadata)
if (node_id and node_id not in known_nodes):
self.log.info(f"Deleting leaked pip: {pip['name']}")
self.azul.public_ip_addresses.delete(
self.resource_group, pip['name'])
for disk in self._listDisks():
node_id = self._metadataMatches(disk, metadata)
if (node_id and node_id not in known_nodes):
self.log.info(f"Deleting leaked disk: {disk['name']}")
self.azul.disks.delete(self.resource_group, disk['name'])
def listInstances(self):
for vm in self._listVirtualMachines():
yield AzureInstance(vm)
def getQuotaLimits(self):
return QuotaInformation(default=math.inf)
def getQuotaForLabel(self, label):
return QuotaInformation(instances=1)
# Local implementation below
def _metadataMatches(self, obj, metadata):
if 'tags' not in obj:
return None
for k, v in metadata.items():
if obj['tags'].get(k) != v:
return None
return obj['tags']['nodepool_node_id']
@staticmethod
def _succeeded(obj):
return obj['properties']['provisioningState'] == 'Succeeded'
def _refresh(self, obj, force=False):
if self._succeeded(obj) and not force:
return obj
if obj['type'] == 'Microsoft.Network/publicIPAddresses':
l = self._listPublicIPAddresses()
if obj['type'] == 'Microsoft.Network/networkInterfaces':
l = self._listNetworkInterfaces()
if obj['type'] == 'Microsoft.Compute/virtualMachines':
l = self._listVirtualMachines()
for new_obj in l:
if new_obj['id'] == obj['id']:
return new_obj
return obj
def _refresh_delete(self, obj):
if obj is None:
return obj
if obj['type'] == 'Microsoft.Network/publicIPAddresses':
l = self._listPublicIPAddresses()
if obj['type'] == 'Microsoft.Network/networkInterfaces':
l = self._listNetworkInterfaces()
if obj['type'] == 'Microsoft.Compute/virtualMachines':
l = self._listVirtualMachines()
for new_obj in l:
if new_obj['id'] == obj['id']:
return new_obj
return None
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listPublicIPAddresses(self):
return self.azul.public_ip_addresses.list(self.resource_group)
def _createPublicIPAddress(self, tags, hostname):
v4_params_create = {
'location': self.provider.location,
'tags': tags,
'properties': {
'publicIpAllocationMethod': 'dynamic',
},
}
return self.azul.public_ip_addresses.create(
self.resource_group,
"%s-nic-pip" % hostname,
v4_params_create,
)
def _deletePublicIPAddress(self, name):
for pip in self._listPublicIPAddresses():
if pip['name'] == name:
break
else:
return None
self.azul.public_ip_addresses.delete(self.resource_group, name)
return pip
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listNetworkInterfaces(self):
return self.azul.network_interfaces.list(self.resource_group)
def _createNetworkInterface(self, tags, hostname, pip):
nic_data = {
'location': self.provider.location,
'tags': tags,
'properties': {
'ipConfigurations': [{
'name': "nodepool-v4-ip-config",
'properties': {
'privateIpAddressVersion': 'IPv4',
'subnet': {
'id': self.provider.subnet_id
},
'publicIpAddress': {
'id': pip['id']
}
}
}]
}
}
if self.provider.ipv6:
nic_data['properties']['ipConfigurations'].append({
'name': "nodepool-v6-ip-config",
'properties': {
'privateIpAddressVersion': 'IPv6',
'subnet': {
'id': self.provider.subnet_id
}
}
})
return self.azul.network_interfaces.create(
self.resource_group,
"%s-nic" % hostname,
nic_data
)
def _deleteNetworkInterface(self, name):
for nic in self._listNetworkInterfaces():
if nic['name'] == name:
break
else:
return None
self.azul.network_interfaces.delete(self.resource_group, name)
return nic
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listVirtualMachines(self):
return self.azul.virtual_machines.list(self.resource_group)
def _createVirtualMachine(self, label, tags, hostname, nic):
return self.azul.virtual_machines.create(
self.resource_group, hostname, {
'location': self.provider.location,
'tags': tags,
'properties': {
'osProfile': {
'computerName': hostname,
'adminUsername': label.cloud_image.username,
'linuxConfiguration': {
'ssh': {
'publicKeys': [{
'path': "/home/%s/.ssh/authorized_keys" % (
label.cloud_image.username),
'keyData': label.cloud_image.key,
}]
},
"disablePasswordAuthentication": True,
}
},
'hardwareProfile': {
'vmSize': label.hardware_profile["vm-size"]
},
'storageProfile': {
'imageReference': label.cloud_image.image_reference
},
'networkProfile': {
'networkInterfaces': [{
'id': nic['id'],
'properties': {
'primary': True,
}
}]
},
},
})
def _deleteVirtualMachine(self, name):
for vm in self._listVirtualMachines():
if vm['name'] == name:
break
else:
return None
self.azul.virtual_machines.delete(self.resource_group, name)
return vm
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listDisks(self):
return self.azul.disks.list(self.resource_group)
def _deleteDisk(self, name):
for disk in self._listNetworkInterfaces():
if disk['name'] == name:
break
else:
return None
self.azul.disks.delete(self.resource_group, name)
return disk

View File

@ -0,0 +1,269 @@
# Copyright 2021 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import requests
import logging
import time
class AzureAuth(requests.auth.AuthBase):
AUTH_URL = "https://login.microsoftonline.com/{tenantId}/oauth2/token"
def __init__(self, credential):
self.log = logging.getLogger("azul.auth")
self.credential = credential
self.token = None
self.expiration = time.time()
def refresh(self):
if self.expiration - time.time() < 60:
self.log.debug('Refreshing authentication token')
url = self.AUTH_URL.format(**self.credential)
data = {
'grant_type': 'client_credentials',
'client_id': self.credential['clientId'],
'client_secret': self.credential['clientSecret'],
'resource': 'https://management.azure.com/',
}
r = requests.post(url, data)
ret = r.json()
self.token = ret['access_token']
self.expiration = float(ret['expires_on'])
def __call__(self, r):
self.refresh()
r.headers["authorization"] = "Bearer " + self.token
return r
class AzureError(Exception):
def __init__(self, status_code, message):
super().__init__(message)
self.status_code = status_code
class AzureNotFoundError(AzureError):
def __init__(self, status_code, message):
super().__init__(status_code, message)
class AzureResourceGroupsCRUD:
def __init__(self, cloud, version):
self.cloud = cloud
self.version = version
def url(self, url, **args):
base_url = (
'https://management.azure.com/subscriptions/{subscriptionId}'
'/resourcegroups/')
url = base_url + url + '?api-version={apiVersion}'
args = args.copy()
args.update(self.cloud.credential)
args['apiVersion'] = self.version
return url.format(**args)
def list(self):
url = self.url('')
return self.cloud.paginate(self.cloud.get(url))
def get(self, name):
url = self.url(name)
return self.cloud.get(url)
def create(self, name, params):
url = self.url(name)
return self.cloud.put(url, params)
def delete(self, name):
url = self.url(name)
return self.cloud.delete(url)
class AzureCRUD:
def __init__(self, cloud, resource, version):
self.cloud = cloud
self.resource = resource
self.version = version
def url(self, url, **args):
base_url = (
'https://management.azure.com/subscriptions/{subscriptionId}'
'/resourceGroups/{resourceGroupName}/providers/')
url = base_url + url + '?api-version={apiVersion}'
args = args.copy()
args.update(self.cloud.credential)
args['apiVersion'] = self.version
return url.format(**args)
def id_url(self, url, **args):
base_url = 'https://management.azure.com'
url = base_url + url + '?api-version={apiVersion}'
args = args.copy()
args['apiVersion'] = self.version
return url.format(**args)
def list(self, resource_group_name):
url = self.url(
self.resource,
resourceGroupName=resource_group_name,
)
return self.cloud.paginate(self.cloud.get(url))
def get_by_id(self, resource_id):
url = self.id_url(resource_id)
return self.cloud.get(url)
def get(self, resource_group_name, name):
url = self.url(
'{_resource}/{_resourceName}',
_resource=self.resource,
_resourceName=name,
resourceGroupName=resource_group_name,
)
return self.cloud.get(url)
def create(self, resource_group_name, name, params):
url = self.url(
'{_resource}/{_resourceName}',
_resource=self.resource,
_resourceName=name,
resourceGroupName=resource_group_name,
)
return self.cloud.put(url, params)
def delete(self, resource_group_name, name):
url = self.url(
'{_resource}/{_resourceName}',
_resource=self.resource,
_resourceName=name,
resourceGroupName=resource_group_name,
)
return self.cloud.delete(url)
class AzureDictResponse(dict):
def __init__(self, response, *args):
super().__init__(*args)
self.response = response
self.last_retry = time.time()
class AzureListResponse(list):
def __init__(self, response, *args):
super().__init__(*args)
self.response = response
self.last_retry = time.time()
class AzureCloud:
TIMEOUT = 60
def __init__(self, credential):
self.credential = credential
self.session = requests.Session()
self.log = logging.getLogger("azul")
self.auth = AzureAuth(credential)
self.network_interfaces = AzureCRUD(
self,
'Microsoft.Network/networkInterfaces',
'2020-07-01')
self.public_ip_addresses = AzureCRUD(
self,
'Microsoft.Network/publicIPAddresses',
'2020-07-01')
self.virtual_machines = AzureCRUD(
self,
'Microsoft.Compute/virtualMachines',
'2020-12-01')
self.disks = AzureCRUD(
self,
'Microsoft.Compute/disks',
'2020-06-30')
self.resource_groups = AzureResourceGroupsCRUD(
self,
'2020-06-01')
def get(self, url, codes=[200]):
return self.request('GET', url, None, codes)
def put(self, url, data, codes=[200, 201]):
return self.request('PUT', url, data, codes)
def delete(self, url, codes=[200, 201, 202, 204]):
return self.request('DELETE', url, None, codes)
def request(self, method, url, data, codes):
self.log.debug('%s: %s %s' % (method, url, data))
response = self.session.request(
method, url, json=data,
auth=self.auth, timeout=self.TIMEOUT,
headers={'Accept': 'application/json',
'Accept-Encoding': 'gzip'})
self.log.debug("Received headers: %s", response.headers)
if response.status_code in codes:
if len(response.text):
self.log.debug("Received: %s", response.text)
ret_data = response.json()
if isinstance(ret_data, list):
return AzureListResponse(response, ret_data)
else:
return AzureDictResponse(response, ret_data)
self.log.debug("Empty response")
return AzureDictResponse(response, {})
err = response.json()
self.log.error(response.text)
if response.status_code == 404:
raise AzureNotFoundError(
response.status_code, err['error']['message'])
else:
raise AzureError(response.status_code, err['error']['message'])
def paginate(self, data):
ret = data['value']
while 'nextLink' in data:
data = self.get(data['nextLink'])
ret += data['value']
return ret
def check_async_operation(self, response):
resp = response.response
location = resp.headers.get(
'Azure-AsyncOperation',
resp.headers.get('Location', None))
if not location:
self.log.debug("No async operation found")
return None
remain = (response.last_retry +
float(resp.headers.get('Retry-After', 2))) - time.time()
self.log.debug("remain time %s", remain)
if remain > 0:
time.sleep(remain)
response.last_retry = time.time()
return self.get(location)
def wait_for_async_operation(self, response, timeout=600):
start = time.time()
while True:
if time.time() - start > timeout:
raise Exception("Timeout waiting for async operation")
ret = self.check_async_operation(response)
if ret is None:
return
if ret['status'] == 'InProgress':
continue
if ret['status'] == 'Succeeded':
return
raise Exception("Unhandled async operation result: %s",
ret['status'])

View File

@ -0,0 +1,225 @@
# Copyright 2018 Red Hat
# Copyright 2021 Acme Gating, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.
import voluptuous as v
import os
from nodepool.driver import ConfigPool
from nodepool.driver import ConfigValue
from nodepool.driver import ProviderConfig
class ProviderCloudImage(ConfigValue):
def __init__(self):
self.name = None
self.image_id = None
self.username = None
self.key = None
self.python_path = None
self.connection_type = None
self.connection_port = None
def __eq__(self, other):
if isinstance(other, ProviderCloudImage):
return (self.name == other.name
and self.image_id == other.image_id
and self.username == other.username
and self.key == other.key
and self.python_path == other.python_path
and self.connection_type == other.connection_type
and self.connection_port == other.connection_port)
return False
def __repr__(self):
return "<ProviderCloudImage %s>" % self.name
@property
def external_name(self):
'''Human readable version of external.'''
return self.image_id or self.name
class AzureLabel(ConfigValue):
def __eq__(self, other):
if ( # other.username != self.username or
# other.imageReference != self.imageReference or
other.hardware_profile != self.hardware_profile):
return False
return True
class AzurePool(ConfigPool):
def __eq__(self, other):
if other.labels != self.labels:
return False
return True
def __repr__(self):
return "<AzurePool %s>" % self.name
def load(self, pool_config):
pass
class AzureProviderConfig(ProviderConfig):
def __init__(self, driver, provider):
self._pools = {}
self.driver_object = driver
self.rate_limit = None
self.launch_retries = None
super().__init__(provider)
def __eq__(self, other):
if (other.location != self.location or
other.pools != self.pools):
return False
return True
@property
def pools(self):
return self._pools
@property
def manage_images(self):
return False
@staticmethod
def reset():
pass
def load(self, config):
self.rate_limit = self.provider.get('rate-limit', 1)
self.launch_retries = self.provider.get('launch-retries', 3)
self.boot_timeout = self.provider.get('boot-timeout', 60)
self.zuul_public_key = self.provider['zuul-public-key']
self.location = self.provider['location']
self.subnet_id = self.provider['subnet-id']
self.ipv6 = self.provider.get('ipv6', False)
self.resource_group = self.provider['resource-group']
self.resource_group_location = self.provider['resource-group-location']
self.auth_path = self.provider.get(
'auth-path', os.getenv('AZURE_AUTH_LOCATION', None))
default_port_mapping = {
'ssh': 22,
'winrm': 5986,
}
self.cloud_images = {}
for image in self.provider['cloud-images']:
i = ProviderCloudImage()
i.name = image['name']
i.username = image['username']
# i.key = image['key']
i.key = self.zuul_public_key
i.image_reference = image['image-reference']
i.connection_type = image.get('connection-type', 'ssh')
i.connection_port = image.get(
'connection-port',
default_port_mapping.get(i.connection_type, 22))
self.cloud_images[i.name] = i
for pool in self.provider.get('pools', []):
pp = AzurePool()
pp.name = pool['name']
pp.provider = self
pp.max_servers = pool['max-servers']
pp.use_internal_ip = bool(pool.get('use-internal-ip', False))
pp.host_key_checking = bool(pool.get(
'host-key-checking', True))
self._pools[pp.name] = pp
pp.labels = {}
for label in pool.get('labels', []):
pl = AzureLabel()
pl.name = label['name']
pl.pool = pp
pp.labels[pl.name] = pl
cloud_image_name = label['cloud-image']
if cloud_image_name:
cloud_image = self.cloud_images.get(
cloud_image_name, None)
if not cloud_image:
raise ValueError(
"cloud-image %s does not exist in provider %s"
" but is referenced in label %s" %
(cloud_image_name, self.name, pl.name))
# pl.imageReference = cloud_image['image-reference']
pl.cloud_image = cloud_image
# pl.username = cloud_image.get('username', 'zuul')
else:
# pl.imageReference = None
pl.cloud_image = None
# pl.username = 'zuul'
pl.hardware_profile = label['hardware-profile']
config.labels[label['name']].pools.append(pp)
pl.tags = label.get('tags', {})
def getSchema(self):
azure_image_reference = {
v.Required('sku'): str,
v.Required('publisher'): str,
v.Required('version'): str,
v.Required('offer'): str,
}
azure_hardware_profile = {
v.Required('vm-size'): str,
}
provider_cloud_images = {
v.Required('name'): str,
'username': str,
v.Required('image-reference'): azure_image_reference,
}
azure_label = {
v.Required('name'): str,
v.Required('hardware-profile'): azure_hardware_profile,
v.Required('cloud-image'): str,
v.Optional('tags'): dict,
}
pool = ConfigPool.getCommonSchemaDict()
pool.update({
v.Required('name'): str,
v.Required('labels'): [azure_label],
})
provider = ProviderConfig.getCommonSchemaDict()
provider.update({
v.Required('zuul-public-key'): str,
v.Required('pools'): [pool],
v.Required('location'): str,
v.Required('resource-group'): str,
v.Required('resource-group-location'): str,
v.Required('subnet-id'): str,
v.Required('cloud-images'): [provider_cloud_images],
v.Optional('auth-path'): str,
})
return v.Schema(provider)
def getSupportedLabels(self, pool_name=None):
labels = set()
for pool in self._pools.values():
if not pool_name or (pool.name == pool_name):
labels.update(pool.labels.keys())
return labels

View File

@ -157,8 +157,9 @@ class StateMachineNodeLauncher(stats.StatsReporter):
raise Exception("Driver implementation error: state "
"machine must produce external ID "
"after first advancement")
self.updateNodeFromInstance(instance)
if state_machine.complete:
node.external_id = state_machine.external_id
self.zk.storeNode(node)
if state_machine.complete and not self.keyscan_future:
self.log.debug("Submitting keyscan request")
self.updateNodeFromInstance(instance)
future = self.manager.keyscan_worker.submit(
@ -252,14 +253,16 @@ class StateMachineNodeDeleter:
if node.external_id:
state_machine.advance()
self.log.debug(f"State machine for {node.id} at "
"{state_machine.state}")
f"{state_machine.state}")
else:
self.state_machine.complete = True
if not self.state_machine.complete:
return
except exceptions.NotFound:
self.log.info(f"Instance {node.external_id} not found in "
"provider {node.provider}")
f"provider {node.provider}")
except Exception:
self.log.exception("Exception deleting instance "
f"{node.external_id} from {node.provider}:")
@ -423,15 +426,21 @@ class StateMachineProvider(Provider, QuotaSupport):
self.state_machine_thread.start()
def stop(self):
self.log.debug("Stopping")
self.running = False
if self.state_machine_thread:
while self.launchers or self.deleters:
time.sleep(1)
self.running = False
if self.keyscan_worker:
self.keyscan_worker.shutdown()
if self.state_machine_thread:
self.running = False
self.log.debug("Stopped")
def join(self):
self.log.debug("Joining")
if self.state_machine_thread:
self.state_machine_thread.join()
self.log.debug("Joined")
def _runStateMachines(self):
while self.running: