Merge "Azure: add quota support"

This commit is contained in:
Zuul 2021-06-18 21:47:24 +00:00 committed by Gerrit Code Review
commit d08f79d09e
5 changed files with 227 additions and 133 deletions

View File

@ -835,6 +835,9 @@ class ConfigValue(object, metaclass=abc.ABCMeta):
return True
return False
def __hash__(self):
return hash(frozenset(self.__dict__))
def __ne__(self, other):
return not self.__eq__(other)

View File

@ -23,14 +23,32 @@ from nodepool.driver import statemachine
from . import azul
def quota_info_from_sku(sku):
if not sku:
return QuotaInformation(instances=1)
cores = None
ram = None
for cap in sku['capabilities']:
if cap['name'] == 'vCPUs':
cores = int(cap['value'])
if cap['name'] == 'MemoryGB':
ram = int(float(cap['value']) * 1024)
return QuotaInformation(
cores=cores,
ram=ram,
instances=1)
class AzureInstance(statemachine.Instance):
def __init__(self, vm, nic=None, pip4=None, pip6=None):
def __init__(self, vm, nic=None, pip4=None, pip6=None, sku=None):
self.external_id = vm['name']
self.metadata = vm['tags'] or {}
self.private_ipv4 = None
self.private_ipv6 = None
self.public_ipv4 = None
self.public_ipv6 = None
self.sku = sku
if nic:
for ip_config_data in nic['properties']['ipConfigurations']:
@ -51,6 +69,9 @@ class AzureInstance(statemachine.Instance):
self.region = vm['location']
self.az = ''
def getQuotaInformation(self):
return quota_info_from_sku(self.sku)
class AzureDeleteStateMachine(statemachine.StateMachine):
VM_DELETING = 'deleting vm'
@ -255,6 +276,8 @@ class AzureAdapter(statemachine.Adapter):
net_info['network'],
net_info.get('subnet', 'default'))
self.subnet_id = subnet['id']
self.skus = {}
self._getSKUs()
def getCreateStateMachine(self, hostname, label, metadata, retries):
return AzureCreateStateMachine(
@ -290,13 +313,26 @@ class AzureAdapter(statemachine.Adapter):
def listInstances(self):
for vm in self._listVirtualMachines():
yield AzureInstance(vm)
sku = self.skus.get((vm['properties']['hardwareProfile']['vmSize'],
vm['location']))
yield AzureInstance(vm, sku=sku)
def getQuotaLimits(self):
return QuotaInformation(default=math.inf)
r = self.azul.compute_usages.list(self.provider.location)
cores = instances = math.inf
for item in r:
if item['name']['value'] == 'cores':
cores = item['limit']
elif item['name']['value'] == 'virtualMachines':
instances = item['limit']
return QuotaInformation(cores=cores,
instances=instances,
default=math.inf)
def getQuotaForLabel(self, label):
return QuotaInformation(instances=1)
sku = self.skus.get((label.hardware_profile["vm-size"],
self.provider.location))
return quota_info_from_sku(sku)
# Local implementation below
@ -344,6 +380,14 @@ class AzureAdapter(statemachine.Adapter):
return new_obj
return None
def _getSKUs(self):
self.log.debug("Querying compute SKUs")
for sku in self.azul.compute_skus.list():
for location in sku['locations']:
key = (sku['name'], location)
self.skus[key] = sku
self.log.debug("Done querying compute SKUs")
@cachetools.func.ttl_cache(maxsize=1, ttl=10)
def _listPublicIPAddresses(self):
return self.azul.public_ip_addresses.list(self.resource_group)

View File

@ -48,138 +48,147 @@ class AzureAuth(requests.auth.AuthBase):
class AzureError(Exception):
def __init__(self, status_code, message):
def __init__(self, status_code, error_code, message):
super().__init__(message)
self.error_code = error_code
self.status_code = status_code
class AzureNotFoundError(AzureError):
def __init__(self, status_code, message):
super().__init__(status_code, message)
class AzureResourceGroupsCRUD:
def __init__(self, cloud, version):
self.cloud = cloud
self.version = version
def url(self, url, **args):
base_url = (
'https://management.azure.com/subscriptions/{subscriptionId}'
'/resourcegroups/')
url = base_url + url + '?api-version={apiVersion}'
args = args.copy()
args.update(self.cloud.credential)
args['apiVersion'] = self.version
return url.format(**args)
def list(self):
url = self.url('')
return self.cloud.paginate(self.cloud.get(url))
def get(self, name):
url = self.url(name)
return self.cloud.get(url)
def create(self, name, params):
url = self.url(name)
return self.cloud.put(url, params)
def delete(self, name):
url = self.url(name)
return self.cloud.delete(url)
pass
class AzureCRUD:
def __init__(self, cloud, resource, version):
self.cloud = cloud
self.resource = resource
self.version = version
base_subscription_url = (
'https://management.azure.com/subscriptions/{subscriptionId}/')
base_url = ''
def url(self, url, **args):
base_url = (
'https://management.azure.com/subscriptions/{subscriptionId}'
'/resourceGroups/{resourceGroupName}/providers/')
url = base_url + url + '?api-version={apiVersion}'
args = args.copy()
args.update(self.cloud.credential)
args['apiVersion'] = self.version
def __init__(self, cloud, **kw):
self.cloud = cloud
self.args = kw.copy()
self.args.update(self.cloud.credential)
def url(self, **kw):
url = (self.base_subscription_url + self.base_url
+ '?api-version={apiVersion}')
args = self.args.copy()
args.update(kw)
return url.format(**args)
def id_url(self, url, **args):
def id_url(self, url, **kw):
base_url = 'https://management.azure.com'
url = base_url + url + '?api-version={apiVersion}'
args = args.copy()
args['apiVersion'] = self.version
args = self.args.copy()
args.update(kw)
return url.format(**args)
def list(self, resource_group_name):
url = self.url(
self.resource,
resourceGroupName=resource_group_name,
)
return self.cloud.paginate(self.cloud.get(url))
def get_by_id(self, resource_id):
url = self.id_url(resource_id)
return self.cloud.get(url)
def get(self, resource_group_name, name):
url = self.url(
'{}/{}'.format(self.resource, name),
resourceGroupName=resource_group_name,
)
return self.cloud.get(url)
def create(self, resource_group_name, name, params):
url = self.url(
'{}/{}'.format(self.resource, name),
resourceGroupName=resource_group_name,
)
return self.cloud.put(url, params)
def delete(self, resource_group_name, name):
url = self.url(
'{}/{}'.format(self.resource, name),
resourceGroupName=resource_group_name,
)
return self.cloud.delete(url)
class AzureSubnetCRUD(AzureCRUD):
def list(self, resource_group_name, virtual_network_name):
url = self.url(
self.resource,
resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
)
def _list(self, **kw):
url = self.url(**kw)
return self.cloud.paginate(self.cloud.get(url))
def get(self, resource_group_name, virtual_network_name, name):
url = self.url(
'{}/{}'.format(self.resource, name),
resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
)
def list(self):
return self._list()
def _get(self, **kw):
url = self.url(**kw)
return self.cloud.get(url)
def create(self, resource_group_name, virtual_network_name, name, params):
url = self.url(
'{}/{}'.format(self.resource, name),
resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
)
def _create(self, params, **kw):
url = self.url(**kw)
return self.cloud.put(url, params)
def delete(self, resource_group_name, virtual_network_name, name):
url = self.url(
'{}/{}'.format(self.resource, name),
resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
)
def _delete(self, **kw):
url = self.url(**kw)
return self.cloud.delete(url)
class AzureResourceGroupsCRUD(AzureCRUD):
base_url = 'resourcegroups/{resourceGroupName}'
def list(self):
return self._list(resourceGroupName='')
def get(self, name):
return self._get(resourceGroupName=name)
def create(self, name, params):
return self._create(params, resourceGroupName=name)
def delete(self, name):
return self._delete(resourceGroupName=name)
class AzureResourceProviderCRUD(AzureCRUD):
base_url = (
'/resourceGroups/{resourceGroupName}/providers/'
'{providerId}/{resource}/{resourceName}')
def list(self, resource_group_name):
return self._list(resourceGroupName=resource_group_name,
resourceName='')
def get(self, resource_group_name, name):
return self._get(resourceGroupName=resource_group_name,
resourceName=name)
def create(self, resource_group_name, name, params):
return self._create(params,
resourceGroupName=resource_group_name,
resourceName=name)
def delete(self, resource_group_name, name):
return self._delete(resourceGroupName=resource_group_name,
resourceName=name)
class AzureNetworkCRUD(AzureCRUD):
base_url = (
'/resourceGroups/{resourceGroupName}/providers/'
'Microsoft.Network/virtualNetworks/{virtualNetworkName}/'
'{resource}/{resourceName}')
def list(self, resource_group_name, virtual_network_name):
return self._list(resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
resourceName='')
def get(self, resource_group_name, virtual_network_name, name):
return self._get(resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
resourceName=name)
def create(self, resource_group_name, virtual_network_name, name, params):
return self._create(params,
resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
resourceName=name)
def delete(self, resource_group_name, virtual_network_name, name):
return self._delete(resourceGroupName=resource_group_name,
virtualNetworkName=virtual_network_name,
resourceName=name)
class AzureLocationCRUD(AzureCRUD):
base_url = (
'/providers/{providerId}/locations/{location}/{resource}')
def list(self, location):
return self._list(location=location)
class AzureProviderCRUD(AzureCRUD):
base_url = (
'/providers/{providerId}/{resource}/')
def list(self):
return self._list()
class AzureDictResponse(dict):
def __init__(self, response, *args):
super().__init__(*args)
@ -202,29 +211,43 @@ class AzureCloud:
self.session = requests.Session()
self.log = logging.getLogger("azul")
self.auth = AzureAuth(credential)
self.network_interfaces = AzureCRUD(
self.network_interfaces = AzureResourceProviderCRUD(
self,
'Microsoft.Network/networkInterfaces',
'2020-07-01')
self.public_ip_addresses = AzureCRUD(
providerId='Microsoft.Network',
resource='networkInterfaces',
apiVersion='2020-07-01')
self.public_ip_addresses = AzureResourceProviderCRUD(
self,
'Microsoft.Network/publicIPAddresses',
'2020-07-01')
self.virtual_machines = AzureCRUD(
providerId='Microsoft.Network',
resource='publicIPAddresses',
apiVersion='2020-07-01')
self.virtual_machines = AzureResourceProviderCRUD(
self,
'Microsoft.Compute/virtualMachines',
'2020-12-01')
self.disks = AzureCRUD(
providerId='Microsoft.Compute',
resource='virtualMachines',
apiVersion='2020-12-01')
self.disks = AzureResourceProviderCRUD(
self,
'Microsoft.Compute/disks',
'2020-06-30')
providerId='Microsoft.Compute',
resource='disks',
apiVersion='2020-06-30')
self.resource_groups = AzureResourceGroupsCRUD(
self,
'2020-06-01')
self.subnets = AzureSubnetCRUD(
apiVersion='2020-06-01')
self.subnets = AzureNetworkCRUD(
self,
'Microsoft.Network/virtualNetworks/{virtualNetworkName}/subnets',
'2020-07-01')
resource='subnets',
apiVersion='2020-07-01')
self.compute_usages = AzureLocationCRUD(
self,
providerId='Microsoft.Compute',
resource='usages',
apiVersion='2020-12-01')
self.compute_skus = AzureProviderCRUD(
self,
providerId='Microsoft.Compute',
resource='skus',
apiVersion='2019-04-01')
def get(self, url, codes=[200]):
return self.request('GET', url, None, codes)
@ -258,9 +281,13 @@ class AzureCloud:
self.log.error(response.text)
if response.status_code == 404:
raise AzureNotFoundError(
response.status_code, err['error']['message'])
response.status_code,
err['error']['code'],
err['error']['message'])
else:
raise AzureError(response.status_code, err['error']['message'])
raise AzureError(response.status_code,
err['error']['code'],
err['error']['message'])
def paginate(self, data):
ret = data['value']

View File

@ -27,7 +27,9 @@ from nodepool.logconfig import get_annotated_logger
from nodepool import stats
from nodepool import exceptions
from nodepool import zk
from kazoo import exceptions as kze
import cachetools
def keyscan(node_id, interface_ip,
@ -417,6 +419,9 @@ class StateMachineProvider(Provider, QuotaSupport):
self.keyscan_worker = None
self.state_machine_thread = None
self.running = False
num_labels = sum([len(pool.labels)
for pool in provider.pools.values()])
self.label_quota_cache = cachetools.LRUCache(num_labels)
def start(self, zk_conn):
super().start(zk_conn)
@ -484,10 +489,17 @@ class StateMachineProvider(Provider, QuotaSupport):
def quotaNeededByLabel(self, ntype, pool):
provider_label = pool.labels[ntype]
qi = self.label_quota_cache.get(provider_label)
if qi is not None:
return qi
try:
return self.adapter.getQuotaForLabel(provider_label)
qi = self.adapter.getQuotaForLabel(provider_label)
self.log.debug("Quota required for %s: %s",
provider_label.name, qi)
except NotImplementedError:
return QuotaInformation()
qi = QuotaInformation()
self.label_quota_cache.setdefault(provider_label, qi)
return qi
def unmanagedQuotaUsed(self):
'''

View File

@ -184,10 +184,11 @@ class QuotaInformation:
be initialized with default which will be typically 0 or math.inf
indicating an infinite limit.
:param cores:
:param instances:
:param ram:
:param default:
:param cores: An integer number of (v)CPU cores.
:param instances: An integer number of instances.
:param ram: An integer amount of RAM in Mebibytes.
:param default: The default value to use for any attribute not supplied
(usually 0 or math.inf).
'''
self.quota = {
'compute': {
@ -302,7 +303,14 @@ class QuotaSupport:
# This is initialized with the full tenant quota and later becomes
# the quota available for nodepool.
nodepool_quota = self.getProviderLimits()
try:
nodepool_quota = self.getProviderLimits()
except Exception:
if self._current_nodepool_quota:
self.log.exception("Unable to get provider quota, "
"using cached value")
return copy.deepcopy(self._current_nodepool_quota['quota'])
raise
self.log.debug("Provider quota for %s: %s",
self.provider.name, nodepool_quota)