mypy: Add type annotations to 'nova.pci'

The 'nova.pci' module is poorly understood and difficult to debug.
Start adding type hints to make this a little easier to parse and catch
dumb errors. Some code needs to be reworked to make 'mypy' happy, but
it's mostly just type annotations.

Note that because of how the 'nova.objects' module works, we need to
delay interpolation by using forward references, or expressing the type
as string literals to be resolved later [1].

[1] https://www.python.org/dev/peps/pep-0484/#forward-references

Change-Id: I2a609606806c6cabdf95d53339335f61743fc5b0
Signed-off-by: Stephen Finucane <sfinucan@redhat.com>
This commit is contained in:
Stephen Finucane 2021-04-22 11:50:53 +01:00
parent eba9d596da
commit 51d16adda6
9 changed files with 347 additions and 171 deletions

View File

@ -1,6 +1,7 @@
nova/compute/manager.py
nova/crypto.py
nova/network/neutron.py
nova/pci
nova/privsep/path.py
nova/scheduler/client/report.py
nova/scheduler/request_filter.py

View File

@ -14,9 +14,11 @@
import abc
import re
import string
import typing as ty
from nova import exception
from nova.i18n import _
from nova import objects
from nova.pci import utils
MAX_VENDOR_ID = 0xFFFF
@ -29,24 +31,35 @@ ANY = '*'
REGEX_ANY = '.*'
PCISpecAddressType = ty.Union[ty.Dict[str, str], str]
class PciAddressSpec(metaclass=abc.ABCMeta):
"""Abstract class for all PCI address spec styles
This class checks the address fields of the pci.passthrough_whitelist
"""
def __init__(self, pci_addr: str) -> None:
self.domain = ''
self.bus = ''
self.slot = ''
self.func = ''
@abc.abstractmethod
def match(self, pci_addr):
pass
def is_single_address(self):
def is_single_address(self) -> bool:
return all([
all(c in string.hexdigits for c in self.domain),
all(c in string.hexdigits for c in self.bus),
all(c in string.hexdigits for c in self.slot),
all(c in string.hexdigits for c in self.func)])
def _set_pci_dev_info(self, prop, maxval, hex_value):
def _set_pci_dev_info(
self, prop: str, maxval: int, hex_value: str
) -> None:
a = getattr(self, prop)
if a == ANY:
return
@ -70,8 +83,10 @@ class PhysicalPciAddress(PciAddressSpec):
This function class will validate the address fields for a single
PCI device.
"""
def __init__(self, pci_addr):
def __init__(self, pci_addr: PCISpecAddressType) -> None:
try:
# TODO(stephenfin): Is this ever actually a string?
if isinstance(pci_addr, dict):
self.domain = pci_addr['domain']
self.bus = pci_addr['bus']
@ -87,7 +102,7 @@ class PhysicalPciAddress(PciAddressSpec):
except (KeyError, ValueError):
raise exception.PciDeviceWrongAddressFormat(address=pci_addr)
def match(self, phys_pci_addr):
def match(self, phys_pci_addr: PciAddressSpec) -> bool:
conditions = [
self.domain == phys_pci_addr.domain,
self.bus == phys_pci_addr.bus,
@ -104,7 +119,7 @@ class PciAddressGlobSpec(PciAddressSpec):
check for wildcards, and insert wildcards where the field is left blank.
"""
def __init__(self, pci_addr):
def __init__(self, pci_addr: str) -> None:
self.domain = ANY
self.bus = ANY
self.slot = ANY
@ -129,7 +144,7 @@ class PciAddressGlobSpec(PciAddressSpec):
self._set_pci_dev_info('bus', MAX_BUS, '%02x')
self._set_pci_dev_info('slot', MAX_SLOT, '%02x')
def match(self, phys_pci_addr):
def match(self, phys_pci_addr: PciAddressSpec) -> bool:
conditions = [
self.domain in (ANY, phys_pci_addr.domain),
self.bus in (ANY, phys_pci_addr.bus),
@ -146,7 +161,8 @@ class PciAddressRegexSpec(PciAddressSpec):
The validation includes check for all PCI address attributes and validate
their regex.
"""
def __init__(self, pci_addr):
def __init__(self, pci_addr: dict) -> None:
try:
self.domain = pci_addr.get('domain', REGEX_ANY)
self.bus = pci_addr.get('bus', REGEX_ANY)
@ -159,7 +175,7 @@ class PciAddressRegexSpec(PciAddressSpec):
except re.error:
raise exception.PciDeviceWrongAddressFormat(address=pci_addr)
def match(self, phys_pci_addr):
def match(self, phys_pci_addr: PciAddressSpec) -> bool:
conditions = [
bool(self.domain_regex.match(phys_pci_addr.domain)),
bool(self.bus_regex.match(phys_pci_addr.bus)),
@ -187,11 +203,13 @@ class WhitelistPciAddress(object):
| passthrough_whitelist = {"vendor_id":"1137","product_id":"0071"}
"""
def __init__(self, pci_addr, is_physical_function):
def __init__(
self, pci_addr: PCISpecAddressType, is_physical_function: bool
) -> None:
self.is_physical_function = is_physical_function
self._init_address_fields(pci_addr)
def _check_physical_function(self):
def _check_physical_function(self) -> None:
if self.pci_address_spec.is_single_address():
self.is_physical_function = (
utils.is_physical_function(
@ -200,7 +218,8 @@ class WhitelistPciAddress(object):
self.pci_address_spec.slot,
self.pci_address_spec.func))
def _init_address_fields(self, pci_addr):
def _init_address_fields(self, pci_addr: PCISpecAddressType) -> None:
self.pci_address_spec: PciAddressSpec
if not self.is_physical_function:
if isinstance(pci_addr, str):
self.pci_address_spec = PciAddressGlobSpec(pci_addr)
@ -212,10 +231,12 @@ class WhitelistPciAddress(object):
else:
self.pci_address_spec = PhysicalPciAddress(pci_addr)
def match(self, pci_addr, pci_phys_addr):
"""Match a device to this PciAddress. Assume this is called given
pci_addr and pci_phys_addr reported by libvirt, no attempt is made to
verify if pci_addr is a VF of pci_phys_addr.
def match(self, pci_addr: str, pci_phys_addr: ty.Optional[str]) -> bool:
"""Match a device to this PciAddress.
Assume this is called with a ``pci_addr`` and ``pci_phys_addr``
reported by libvirt. No attempt is made to verify if ``pci_addr`` is a
VF of ``pci_phys_addr``.
:param pci_addr: PCI address of the device to match.
:param pci_phys_addr: PCI address of the parent of the device to match
@ -237,51 +258,57 @@ class WhitelistPciAddress(object):
class PciDeviceSpec(PciAddressSpec):
def __init__(self, dev_spec):
def __init__(self, dev_spec: ty.Dict[str, str]) -> None:
self.tags = dev_spec
self._init_dev_details()
def _init_dev_details(self):
def _init_dev_details(self) -> None:
self.vendor_id = self.tags.pop("vendor_id", ANY)
self.product_id = self.tags.pop("product_id", ANY)
self.dev_name = self.tags.pop("devname", None)
self.address: ty.Optional[WhitelistPciAddress] = None
# Note(moshele): The address attribute can be a string or a dict.
# For glob syntax or specific pci it is a string and for regex syntax
# it is a dict. The WhitelistPciAddress class handles both types.
self.address = self.tags.pop("address", None)
self.dev_name = self.tags.pop("devname", None)
address = self.tags.pop("address", None)
self.vendor_id = self.vendor_id.strip()
self._set_pci_dev_info('vendor_id', MAX_VENDOR_ID, '%04x')
self._set_pci_dev_info('product_id', MAX_PRODUCT_ID, '%04x')
if self.address and self.dev_name:
if address and self.dev_name:
raise exception.PciDeviceInvalidDeviceName()
if not self.dev_name:
pci_address = self.address or "*:*:*.*"
self.address = WhitelistPciAddress(pci_address, False)
def match(self, dev_dict):
if not self.dev_name:
self.address = WhitelistPciAddress(address or '*:*:*.*', False)
def match(self, dev_dict: ty.Dict[str, str]) -> bool:
address_obj: ty.Optional[WhitelistPciAddress]
if self.dev_name:
address_str, pf = utils.get_function_by_ifname(
self.dev_name)
address_str, pf = utils.get_function_by_ifname(self.dev_name)
if not address_str:
return False
# Note(moshele): In this case we always passing a string
# of the PF pci address
address_obj = WhitelistPciAddress(address_str, pf)
elif self.address:
else: # use self.address
address_obj = self.address
if not address_obj:
return False
return all([
self.vendor_id in (ANY, dev_dict['vendor_id']),
self.product_id in (ANY, dev_dict['product_id']),
address_obj.match(dev_dict['address'],
dev_dict.get('parent_addr'))])
def match_pci_obj(self, pci_obj):
def match_pci_obj(self, pci_obj: 'objects.PciDevice') -> bool:
return self.match({'vendor_id': pci_obj.vendor_id,
'product_id': pci_obj.product_id,
'address': pci_obj.address,
'parent_addr': pci_obj.parent_addr})
def get_tags(self):
def get_tags(self) -> ty.Dict[str, str]:
return self.tags

View File

@ -15,11 +15,13 @@
# under the License.
import collections
import typing as ty
from oslo_config import cfg
from oslo_log import log as logging
from oslo_serialization import jsonutils
from nova import context as ctx
from nova import exception
from nova import objects
from nova.objects import fields
@ -29,6 +31,9 @@ from nova.pci import whitelist
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
MappingType = ty.Dict[str, ty.List['objects.PciDevice']]
PCIInvType = ty.DefaultDict[str, ty.List['objects.PciDevice']]
class PciDevTracker(object):
"""Manage pci devices in a compute node.
@ -51,17 +56,19 @@ class PciDevTracker(object):
are saved.
"""
def __init__(self, context, compute_node):
def __init__(
self,
context: ctx.RequestContext,
compute_node: 'objects.ComputeNode',
):
"""Create a pci device tracker.
:param context: The request context.
:param compute_node: The object.ComputeNode whose PCI devices we're
tracking.
"""
super(PciDevTracker, self).__init__()
self.stale = {}
self.node_id = compute_node.id
self.stale: ty.Dict[str, objects.PciDevice] = {}
self.node_id: str = compute_node.id
self.dev_filter = whitelist.Whitelist(CONF.pci.passthrough_whitelist)
numa_topology = compute_node.numa_topology
if numa_topology:
@ -76,9 +83,10 @@ class PciDevTracker(object):
self._build_device_tree(self.pci_devs)
self._initial_instance_usage()
def _initial_instance_usage(self):
self.allocations = collections.defaultdict(list)
self.claims = collections.defaultdict(list)
def _initial_instance_usage(self) -> None:
self.allocations: PCIInvType = collections.defaultdict(list)
self.claims: PCIInvType = collections.defaultdict(list)
for dev in self.pci_devs:
uuid = dev.instance_uuid
if dev.status == fields.PciDeviceStatus.CLAIMED:
@ -88,7 +96,7 @@ class PciDevTracker(object):
elif dev.status == fields.PciDeviceStatus.AVAILABLE:
self.stats.add_device(dev)
def save(self, context):
def save(self, context: ctx.RequestContext) -> None:
for dev in self.pci_devs:
if dev.obj_what_changed():
with dev.obj_alternate_context(context):
@ -97,10 +105,12 @@ class PciDevTracker(object):
self.pci_devs.objects.remove(dev)
@property
def pci_stats(self):
def pci_stats(self) -> stats.PciDeviceStats:
return self.stats
def update_devices_from_hypervisor_resources(self, devices_json):
def update_devices_from_hypervisor_resources(
self, devices_json: str,
) -> None:
"""Sync the pci device tracker with hypervisor information.
To support pci device hot plug, we sync with the hypervisor
@ -159,7 +169,7 @@ class PciDevTracker(object):
self._set_hvdevs(devices)
@staticmethod
def _build_device_tree(all_devs):
def _build_device_tree(all_devs: ty.List['objects.PciDevice']) -> None:
"""Build a tree of devices that represents parent-child relationships.
We need to have the relationships set up so that we can easily make
@ -196,7 +206,7 @@ class PciDevTracker(object):
if dev.parent_device:
parents[dev.parent_addr].child_devices.append(dev)
def _set_hvdevs(self, devices):
def _set_hvdevs(self, devices: ty.List[ty.Dict[str, ty.Any]]) -> None:
exist_addrs = set([dev.address for dev in self.pci_devs])
new_addrs = set([dev['address'] for dev in devices])
@ -243,6 +253,7 @@ class PciDevTracker(object):
self.stats.remove_device(existed)
else:
# Update tracked devices.
new_value: ty.Dict[str, ty.Any]
new_value = next((dev for dev in devices if
dev['address'] == existed.address))
new_value['compute_node_id'] = self.node_id
@ -276,7 +287,12 @@ class PciDevTracker(object):
self._build_device_tree(self.pci_devs)
def _claim_instance(self, context, pci_requests, instance_numa_topology):
def _claim_instance(
self,
context: ctx.RequestContext,
pci_requests: 'objects.InstancePCIRequests',
instance_numa_topology: 'objects.InstanceNUMATopology',
) -> ty.List['objects.PciDevice']:
instance_cells = None
if instance_numa_topology:
instance_cells = instance_numa_topology.cells
@ -284,7 +300,7 @@ class PciDevTracker(object):
devs = self.stats.consume_requests(pci_requests.requests,
instance_cells)
if not devs:
return None
return []
instance_uuid = pci_requests.instance_uuid
for dev in devs:
@ -296,18 +312,15 @@ class PciDevTracker(object):
{'instance': instance_uuid})
return devs
def _allocate_instance(self, instance, devs):
for dev in devs:
dev.allocate(instance)
def claim_instance(
self,
context: ctx.RequestContext,
pci_requests: 'objects.InstancePCIRequests',
instance_numa_topology: 'objects.InstanceNUMATopology',
) -> ty.List['objects.PciDevice']:
def allocate_instance(self, instance):
devs = self.claims.pop(instance['uuid'], [])
self._allocate_instance(instance, devs)
if devs:
self.allocations[instance['uuid']] += devs
def claim_instance(self, context, pci_requests, instance_numa_topology):
devs = []
if self.pci_devs and pci_requests.requests:
instance_uuid = pci_requests.instance_uuid
devs = self._claim_instance(context, pci_requests,
@ -316,7 +329,21 @@ class PciDevTracker(object):
self.claims[instance_uuid] = devs
return devs
def free_device(self, dev, instance):
def _allocate_instance(
self, instance: 'objects.Instance', devs: ty.List['objects.PciDevice'],
) -> None:
for dev in devs:
dev.allocate(instance)
def allocate_instance(self, instance: 'objects.Instance') -> None:
devs = self.claims.pop(instance['uuid'], [])
self._allocate_instance(instance, devs)
if devs:
self.allocations[instance['uuid']] += devs
def free_device(
self, dev: 'objects.PciDevice', instance: 'objects.Instance'
) -> None:
"""Free device from pci resource tracker
:param dev: cloned pci device object that needs to be free
@ -335,7 +362,11 @@ class PciDevTracker(object):
break
def _remove_device_from_pci_mapping(
self, instance_uuid, pci_device, pci_mapping):
self,
instance_uuid: str,
pci_device: 'objects.PciDevice',
pci_mapping: MappingType,
) -> None:
"""Remove a PCI device from allocations or claims.
If there are no more PCI devices, pop the uuid.
@ -346,7 +377,9 @@ class PciDevTracker(object):
if len(pci_devices) == 0:
pci_mapping.pop(instance_uuid, None)
def _free_device(self, dev, instance=None):
def _free_device(
self, dev: 'objects.PciDevice', instance: 'objects.Instance' = None,
) -> None:
freed_devs = dev.free(instance)
stale = self.stale.pop(dev.address, None)
if stale:
@ -354,31 +387,41 @@ class PciDevTracker(object):
for dev in freed_devs:
self.stats.add_device(dev)
def free_instance_allocations(self, context, instance):
def free_instance_allocations(
self, context: ctx.RequestContext, instance: 'objects.Instance',
) -> None:
"""Free devices that are in ALLOCATED state for instance.
:param context: user request context (nova.context.RequestContext)
:param context: user request context
:param instance: instance object
"""
if self.allocations.pop(instance['uuid'], None):
for dev in self.pci_devs:
if (dev.status == fields.PciDeviceStatus.ALLOCATED and
dev.instance_uuid == instance['uuid']):
self._free_device(dev)
if not self.allocations.pop(instance['uuid'], None):
return
def free_instance_claims(self, context, instance):
for dev in self.pci_devs:
if (dev.status == fields.PciDeviceStatus.ALLOCATED and
dev.instance_uuid == instance['uuid']):
self._free_device(dev)
def free_instance_claims(
self, context: ctx.RequestContext, instance: 'objects.Instance',
) -> None:
"""Free devices that are in CLAIMED state for instance.
:param context: user request context (nova.context.RequestContext)
:param instance: instance object
"""
if self.claims.pop(instance['uuid'], None):
for dev in self.pci_devs:
if (dev.status == fields.PciDeviceStatus.CLAIMED and
dev.instance_uuid == instance['uuid']):
self._free_device(dev)
if not self.claims.pop(instance['uuid'], None):
return
def free_instance(self, context, instance):
for dev in self.pci_devs:
if (dev.status == fields.PciDeviceStatus.CLAIMED and
dev.instance_uuid == instance['uuid']):
self._free_device(dev)
def free_instance(
self, context: ctx.RequestContext, instance: 'objects.Instance',
) -> None:
"""Free devices that are in CLAIMED or ALLOCATED state for instance.
:param context: user request context (nova.context.RequestContext)
@ -392,9 +435,13 @@ class PciDevTracker(object):
self.free_instance_allocations(context, instance)
self.free_instance_claims(context, instance)
def update_pci_for_instance(self, context, instance, sign):
"""Update PCI usage information if devices are de/allocated.
"""
def update_pci_for_instance(
self,
context: ctx.RequestContext,
instance: 'objects.Instance',
sign: int,
) -> None:
"""Update PCI usage information if devices are de/allocated."""
if not self.pci_devs:
return
@ -403,7 +450,11 @@ class PciDevTracker(object):
if sign == 1:
self.allocate_instance(instance)
def clean_usage(self, instances, migrations):
def clean_usage(
self,
instances: 'objects.InstanceList',
migrations: 'objects.MigrationList',
) -> None:
"""Remove all usages for instances not passed in the parameter.
The caller should hold the COMPUTE_RESOURCE_SEMAPHORE lock
@ -425,7 +476,9 @@ class PciDevTracker(object):
self._free_device(dev)
def get_instance_pci_devs(inst, request_id=None):
def get_instance_pci_devs(
inst: 'objects.Instance', request_id: str = None,
) -> ty.List['objects.PciDevice']:
"""Get the devices allocated to one or all requests for an instance.
- For generic PCI request, the request id is None.
@ -437,5 +490,8 @@ def get_instance_pci_devs(inst, request_id=None):
pci_devices = inst.pci_devices
if pci_devices is None:
return []
return [device for device in pci_devices if
device.request_id == request_id or request_id == 'all']
return [
device for device in pci_devices if
device.request_id == request_id or request_id == 'all'
]

View File

@ -38,11 +38,14 @@
product_id is "0442" or "0443".
"""
import typing as ty
import jsonschema
from oslo_log import log as logging
from oslo_serialization import jsonutils
import nova.conf
from nova import context as ctx
from nova import exception
from nova.i18n import _
from nova.network import model as network_model
@ -50,7 +53,8 @@ from nova import objects
from nova.objects import fields as obj_fields
from nova.pci import utils
LOG = logging.getLogger(__name__)
Alias = ty.Dict[str, ty.Tuple[str, ty.List[ty.Dict[str, str]]]]
PCI_NET_TAG = 'physical_network'
PCI_TRUSTED_TAG = 'trusted'
PCI_DEVICE_TYPE_TAG = 'dev_type'
@ -61,6 +65,7 @@ DEVICE_TYPE_FOR_VNIC_TYPE = {
}
CONF = nova.conf.CONF
LOG = logging.getLogger(__name__)
_ALIAS_SCHEMA = {
"type": "object",
@ -104,18 +109,19 @@ _ALIAS_SCHEMA = {
}
def _get_alias_from_config():
def _get_alias_from_config() -> Alias:
"""Parse and validate PCI aliases from the nova config.
:returns: A dictionary where the keys are device names and the values are
tuples of form ``(specs, numa_policy)``. ``specs`` is a list of PCI
device specs, while ``numa_policy`` describes the required NUMA
affinity of the device(s).
tuples of form ``(numa_policy, specs)``. ``numa_policy`` describes the
required NUMA affinity of the device(s), while ``specs`` is a list of
PCI device specs.
:raises: exception.PciInvalidAlias if two aliases with the same name have
different device types or different NUMA policies.
"""
jaliases = CONF.pci.alias
aliases = {} # map alias name to alias spec list
# map alias name to alias spec list
aliases: Alias = {}
try:
for jsonspecs in jaliases:
spec = jsonutils.loads(jsonspecs)
@ -153,17 +159,18 @@ def _get_alias_from_config():
return aliases
def _translate_alias_to_requests(alias_spec, affinity_policy=None):
def _translate_alias_to_requests(
alias_spec: str, affinity_policy: str = None,
) -> ty.List['objects.InstancePCIRequest']:
"""Generate complete pci requests from pci aliases in extra_spec."""
pci_aliases = _get_alias_from_config()
pci_requests = []
pci_requests: ty.List[objects.InstancePCIRequest] = []
for name, count in [spec.split(':') for spec in alias_spec.split(',')]:
name = name.strip()
if name not in pci_aliases:
raise exception.PciRequestAliasNotDefined(alias=name)
count = int(count)
numa_policy, spec = pci_aliases[name]
policy = affinity_policy or numa_policy
@ -172,14 +179,18 @@ def _translate_alias_to_requests(alias_spec, affinity_policy=None):
# handling for InstancePCIRequests created from the flavor. So it is
# left empty.
pci_requests.append(objects.InstancePCIRequest(
count=count,
count=int(count),
spec=spec,
alias_name=name,
numa_policy=policy))
return pci_requests
def get_instance_pci_request_from_vif(context, instance, vif):
def get_instance_pci_request_from_vif(
context: ctx.RequestContext,
instance: 'objects.Instance',
vif: network_model.VIF,
) -> ty.Optional['objects.InstancePCIRequest']:
"""Given an Instance, return the PCI request associated
to the PCI device related to the given VIF (if any) on the
compute node the instance is currently running.
@ -233,7 +244,9 @@ def get_instance_pci_request_from_vif(context, instance, vif):
node_id=cn_id)
def get_pci_requests_from_flavor(flavor, affinity_policy=None):
def get_pci_requests_from_flavor(
flavor: 'objects.Flavor', affinity_policy: str = None,
) -> 'objects.InstancePCIRequests':
"""Validate and return PCI requests.
The ``pci_passthrough:alias`` extra spec describes the flavor's PCI
@ -279,7 +292,7 @@ def get_pci_requests_from_flavor(flavor, affinity_policy=None):
:raises: exception.PciInvalidAlias if the configuration contains invalid
aliases.
"""
pci_requests = []
pci_requests: ty.List[objects.InstancePCIRequest] = []
if ('extra_specs' in flavor and
'pci_passthrough:alias' in flavor['extra_specs']):
pci_requests = _translate_alias_to_requests(

View File

@ -14,20 +14,29 @@
# License for the specific language governing permissions and limitations
# under the License.
import copy
import typing as ty
from oslo_config import cfg
from oslo_log import log as logging
from nova import exception
from nova import objects
from nova.objects import fields
from nova.objects import pci_device_pool
from nova.pci import utils
from nova.pci import whitelist
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
# TODO(stephenfin): We might want to use TypedDict here. Refer to
# https://mypy.readthedocs.io/en/latest/kinds_of_types.html#typeddict for
# more information.
Pool = ty.Dict[str, ty.Any]
class PciDeviceStats(object):
"""PCI devices summary information.
@ -54,32 +63,42 @@ class PciDeviceStats(object):
pool_keys = ['product_id', 'vendor_id', 'numa_node', 'dev_type']
def __init__(self, numa_topology, stats=None, dev_filter=None):
super(PciDeviceStats, self).__init__()
def __init__(
self,
numa_topology: 'objects.NUMATopology',
stats: 'objects.PCIDevicePoolList' = None,
dev_filter: whitelist.Whitelist = None,
) -> None:
self.numa_topology = numa_topology
# NOTE(sbauza): Stats are a PCIDevicePoolList object
self.pools = [pci_pool.to_dict()
for pci_pool in stats] if stats else []
self.pools = (
[pci_pool.to_dict() for pci_pool in stats] if stats else []
)
self.pools.sort(key=lambda item: len(item))
self.dev_filter = dev_filter or whitelist.Whitelist(
CONF.pci.passthrough_whitelist)
def _equal_properties(self, dev, entry, matching_keys):
def _equal_properties(
self, dev: Pool, entry: Pool, matching_keys: ty.List[str],
) -> bool:
return all(dev.get(prop) == entry.get(prop)
for prop in matching_keys)
def _find_pool(self, dev_pool):
def _find_pool(self, dev_pool: Pool) -> ty.Optional[Pool]:
"""Return the first pool that matches dev."""
for pool in self.pools:
pool_keys = pool.copy()
del pool_keys['count']
del pool_keys['devices']
if (len(pool_keys.keys()) == len(dev_pool.keys()) and
self._equal_properties(dev_pool, pool_keys, dev_pool.keys())):
self._equal_properties(dev_pool, pool_keys, list(dev_pool))):
return pool
def _create_pool_keys_from_dev(self, dev):
"""create a stats pool dict that this dev is supposed to be part of
return None
def _create_pool_keys_from_dev(
self, dev: 'objects.PciDevice',
) -> ty.Optional[Pool]:
"""Create a stats pool dict that this dev is supposed to be part of
Note that this pool dict contains the stats pool's keys and their
values. 'count' and 'devices' are not included.
@ -88,7 +107,7 @@ class PciDeviceStats(object):
# This can happen during initial sync up with the controller
devspec = self.dev_filter.get_devspec(dev)
if not devspec:
return
return None
tags = devspec.get_tags()
pool = {k: getattr(dev, k) for k in self.pool_keys}
if tags:
@ -103,7 +122,9 @@ class PciDeviceStats(object):
pool['parent_ifname'] = dev.extra_info['parent_ifname']
return pool
def _get_pool_with_device_type_mismatch(self, dev):
def _get_pool_with_device_type_mismatch(
self, dev: 'objects.PciDevice',
) -> ty.Optional[ty.Tuple[Pool, 'objects.PciDevice']]:
"""Check for device type mismatch in the pools for a given device.
Return (pool, device) if device type does not match or a single None
@ -118,18 +139,18 @@ class PciDeviceStats(object):
return None
def update_device(self, dev):
def update_device(self, dev: 'objects.PciDevice') -> None:
"""Update a device to its matching pool."""
pool_device_info = self._get_pool_with_device_type_mismatch(dev)
if pool_device_info is None:
return
return None
pool, device = pool_device_info
pool['devices'].remove(device)
self._decrease_pool_count(self.pools, pool)
self.add_device(dev)
def add_device(self, dev):
def add_device(self, dev: 'objects.PciDevice') -> None:
"""Add a device to its matching pool."""
dev_pool = self._create_pool_keys_from_dev(dev)
if dev_pool:
@ -144,7 +165,9 @@ class PciDeviceStats(object):
pool['devices'].append(dev)
@staticmethod
def _decrease_pool_count(pool_list, pool, count=1):
def _decrease_pool_count(
pool_list: ty.List[Pool], pool: Pool, count: int = 1,
) -> int:
"""Decrement pool's size by count.
If pool becomes empty, remove pool from pool_list.
@ -157,7 +180,7 @@ class PciDeviceStats(object):
pool_list.remove(pool)
return count
def remove_device(self, dev):
def remove_device(self, dev: 'objects.PciDevice') -> None:
"""Remove one device from the first pool that it matches."""
dev_pool = self._create_pool_keys_from_dev(dev)
if dev_pool:
@ -168,14 +191,20 @@ class PciDeviceStats(object):
pool['devices'].remove(dev)
self._decrease_pool_count(self.pools, pool)
def get_free_devs(self):
free_devs = []
def get_free_devs(self) -> ty.List['objects.PciDevice']:
free_devs: ty.List[objects.PciDevice] = []
for pool in self.pools:
free_devs.extend(pool['devices'])
return free_devs
def consume_requests(self, pci_requests, numa_cells=None):
alloc_devices = []
def consume_requests(
self,
pci_requests: 'objects.InstancePCIRequests',
numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None,
) -> ty.Optional[ty.List['objects.PciDevice']]:
alloc_devices: ty.List[objects.PciDevice] = []
for request in pci_requests:
count = request.count
@ -212,7 +241,7 @@ class PciDeviceStats(object):
return alloc_devices
def _handle_device_dependents(self, pci_dev):
def _handle_device_dependents(self, pci_dev: 'objects.PciDevice') -> None:
"""Remove device dependents or a parent from pools.
In case the device is a PF, all of it's dependent VFs should
@ -238,7 +267,9 @@ class PciDeviceStats(object):
except exception.PciDeviceNotFound:
return
def _filter_pools_for_spec(self, pools, request):
def _filter_pools_for_spec(
self, pools: ty.List[Pool], request: 'objects.InstancePCIRequest',
) -> ty.List[Pool]:
"""Filter out pools that don't match the request's device spec.
Exclude pools that do not match the specified ``vendor_id``,
@ -257,7 +288,12 @@ class PciDeviceStats(object):
if utils.pci_device_prop_match(pool, request_specs)
]
def _filter_pools_for_numa_cells(self, pools, request, numa_cells):
def _filter_pools_for_numa_cells(
self,
pools: ty.List[Pool],
request: 'objects.InstancePCIRequest',
numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']],
) -> ty.List[Pool]:
"""Filter out pools with the wrong NUMA affinity, if required.
Exclude pools that do not have *suitable* PCI NUMA affinity.
@ -335,7 +371,11 @@ class PciDeviceStats(object):
return sorted(
pools, key=lambda pool: pool.get('numa_node') not in numa_cell_ids)
def _filter_pools_for_socket_affinity(self, pools, numa_cells):
def _filter_pools_for_socket_affinity(
self,
pools: ty.List[Pool],
numa_cells: ty.List['objects.InstanceNUMACell'],
) -> ty.List[Pool]:
host_cells = self.numa_topology.cells
# bail early if we don't have socket information for all host_cells.
# This could happen if we're running on an weird older system with
@ -368,7 +408,9 @@ class PciDeviceStats(object):
)
]
def _filter_pools_for_unrequested_pfs(self, pools, request):
def _filter_pools_for_unrequested_pfs(
self, pools: ty.List[Pool], request: 'objects.InstancePCIRequest',
) -> ty.List[Pool]:
"""Filter out pools with PFs, unless these are required.
This is necessary in cases where PFs and VFs have the same product_id
@ -390,7 +432,11 @@ class PciDeviceStats(object):
]
return pools
def _filter_pools_for_unrequested_vdpa_devices(self, pools, request):
def _filter_pools_for_unrequested_vdpa_devices(
self,
pools: ty.List[Pool],
request: 'objects.InstancePCIRequest',
) -> ty.List[Pool]:
"""Filter out pools with VDPA devices, unless these are required.
This is necessary as vdpa devices require special handling and
@ -412,7 +458,12 @@ class PciDeviceStats(object):
]
return pools
def _filter_pools(self, pools, request, numa_cells):
def _filter_pools(
self,
pools: ty.List[Pool],
request: 'objects.InstancePCIRequest',
numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']],
) -> ty.Optional[ty.List[Pool]]:
"""Determine if an individual PCI request can be met.
Filter pools, which are collections of devices with similar traits, to
@ -502,7 +553,11 @@ class PciDeviceStats(object):
return pools
def support_requests(self, requests, numa_cells=None):
def support_requests(
self,
requests: ty.List['objects.InstancePCIRequest'],
numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None,
) -> bool:
"""Determine if the PCI requests can be met.
Determine, based on a compute node's PCI stats, if an instance can be
@ -524,7 +579,12 @@ class PciDeviceStats(object):
self._filter_pools(self.pools, r, numa_cells) for r in requests
)
def _apply_request(self, pools, request, numa_cells=None):
def _apply_request(
self,
pools: ty.List[Pool],
request: 'objects.InstancePCIRequest',
numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None,
) -> bool:
"""Apply an individual PCI request.
Apply a PCI request against a given set of PCI device pools, which are
@ -558,7 +618,11 @@ class PciDeviceStats(object):
return True
def apply_requests(self, requests, numa_cells=None):
def apply_requests(
self,
requests: ty.List['objects.InstancePCIRequest'],
numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None,
) -> None:
"""Apply PCI requests to the PCI stats.
This is used in multiple instance creation, when the scheduler has to
@ -580,22 +644,26 @@ class PciDeviceStats(object):
):
raise exception.PciDeviceRequestFailed(requests=requests)
def __iter__(self):
# 'devices' shouldn't be part of stats
pools = []
def __iter__(self) -> ty.Iterator[Pool]:
pools: ty.List[Pool] = []
for pool in self.pools:
tmp = {k: v for k, v in pool.items() if k != 'devices'}
pools.append(tmp)
pool = copy.deepcopy(pool)
# 'devices' shouldn't be part of stats
if 'devices' in pool:
del pool['devices']
pools.append(pool)
return iter(pools)
def clear(self):
def clear(self) -> None:
"""Clear all the stats maintained."""
self.pools = []
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, PciDeviceStats):
return NotImplemented
return self.pools == other.pools
def to_device_pools_obj(self):
def to_device_pools_obj(self) -> 'objects.PciDevicePoolList':
"""Return the contents of the pools as a PciDevicePoolList object."""
stats = [x for x in self]
return pci_device_pool.from_pci_stats(stats)

View File

@ -14,15 +14,19 @@
# License for the specific language governing permissions and limitations
# under the License.
import glob
import os
import re
import typing as ty
from oslo_log import log as logging
from nova import exception
if ty.TYPE_CHECKING:
# avoid circular import
from nova.pci import stats
LOG = logging.getLogger(__name__)
PCI_VENDOR_PATTERN = "^(hex{4})$".replace("hex", r"[\da-fA-F]")
@ -30,11 +34,12 @@ _PCI_ADDRESS_PATTERN = ("^(hex{4}):(hex{2}):(hex{2}).(oct{1})$".
replace("hex", r"[\da-fA-F]").
replace("oct", "[0-7]"))
_PCI_ADDRESS_REGEX = re.compile(_PCI_ADDRESS_PATTERN)
_SRIOV_TOTALVFS = "sriov_totalvfs"
def pci_device_prop_match(pci_dev, specs):
def pci_device_prop_match(
pci_dev: 'stats.Pool', specs: ty.List[ty.Dict[str, str]],
) -> bool:
"""Check if the pci_dev meet spec requirement
Specs is a list of PCI device property requirements.
@ -47,7 +52,8 @@ def pci_device_prop_match(pci_dev, specs):
"capabilities_network": ["rx", "tx", "tso", "gso"]}]
"""
def _matching_devices(spec):
def _matching_devices(spec: ty.Dict[str, str]) -> bool:
for k, v in spec.items():
pci_dev_v = pci_dev.get(k)
if isinstance(v, list) and isinstance(pci_dev_v, list):
@ -69,8 +75,10 @@ def pci_device_prop_match(pci_dev, specs):
return any(_matching_devices(spec) for spec in specs)
def parse_address(address):
"""Returns (domain, bus, slot, function) from PCI address that is stored in
def parse_address(address: str) -> ty.Sequence[str]:
"""Parse a PCI address.
Returns (domain, bus, slot, function) from PCI address that is stored in
PciDevice DB table.
"""
m = _PCI_ADDRESS_REGEX.match(address)
@ -79,7 +87,7 @@ def parse_address(address):
return m.groups()
def get_pci_address_fields(pci_addr):
def get_pci_address_fields(pci_addr: str) -> ty.Tuple[str, str, str, str]:
"""Parse a fully-specified PCI device address.
Does not validate that the components are valid hex or wildcard values.
@ -92,7 +100,7 @@ def get_pci_address_fields(pci_addr):
return domain, bus, slot, func
def get_pci_address(domain, bus, slot, func):
def get_pci_address(domain: str, bus: str, slot: str, func: str) -> str:
"""Assembles PCI address components into a fully-specified PCI address.
Does not validate that the components are valid hex or wildcard values.
@ -103,7 +111,7 @@ def get_pci_address(domain, bus, slot, func):
return '%s:%s:%s.%s' % (domain, bus, slot, func)
def get_function_by_ifname(ifname):
def get_function_by_ifname(ifname: str) -> ty.Tuple[ty.Optional[str], bool]:
"""Given the device name, returns the PCI address of a device
and returns True if the address is in a physical function.
"""
@ -121,7 +129,9 @@ def get_function_by_ifname(ifname):
return None, False
def is_physical_function(domain, bus, slot, function):
def is_physical_function(
domain: str, bus: str, slot: str, function: str,
) -> bool:
dev_path = "/sys/bus/pci/devices/%(d)s:%(b)s:%(s)s.%(f)s/" % {
"d": domain, "b": bus, "s": slot, "f": function}
if os.path.isdir(dev_path):
@ -134,7 +144,7 @@ def is_physical_function(domain, bus, slot, function):
return False
def _get_sysfs_netdev_path(pci_addr, pf_interface):
def _get_sysfs_netdev_path(pci_addr: str, pf_interface: bool) -> str:
"""Get the sysfs path based on the PCI address of the device.
Assumes a networking device - will not check for the existence of the path.
@ -144,7 +154,9 @@ def _get_sysfs_netdev_path(pci_addr, pf_interface):
return "/sys/bus/pci/devices/%s/net" % pci_addr
def get_ifname_by_pci_address(pci_addr, pf_interface=False):
def get_ifname_by_pci_address(
pci_addr: str, pf_interface: bool = False,
) -> str:
"""Get the interface name based on a VF's pci address.
The returned interface name is either the parent PF's or that of the VF
@ -158,7 +170,7 @@ def get_ifname_by_pci_address(pci_addr, pf_interface=False):
raise exception.PciDeviceNotFoundById(id=pci_addr)
def get_mac_by_pci_address(pci_addr, pf_interface=False):
def get_mac_by_pci_address(pci_addr: str, pf_interface: bool = False) -> str:
"""Get the MAC address of the nic based on its PCI address.
Raises PciDeviceNotFoundById in case the pci device is not a NIC
@ -179,7 +191,7 @@ def get_mac_by_pci_address(pci_addr, pf_interface=False):
raise exception.PciDeviceNotFoundById(id=pci_addr)
def get_vf_num_by_pci_address(pci_addr):
def get_vf_num_by_pci_address(pci_addr: str) -> str:
"""Get the VF number based on a VF's pci address
A VF is associated with an VF number, which ip link command uses to
@ -188,14 +200,14 @@ def get_vf_num_by_pci_address(pci_addr):
VIRTFN_RE = re.compile(r"virtfn(\d+)")
virtfns_path = "/sys/bus/pci/devices/%s/physfn/virtfn*" % (pci_addr)
vf_num = None
try:
for vf_path in glob.iglob(virtfns_path):
if re.search(pci_addr, os.readlink(vf_path)):
t = VIRTFN_RE.search(vf_path)
for vf_path in glob.iglob(virtfns_path):
if re.search(pci_addr, os.readlink(vf_path)):
t = VIRTFN_RE.search(vf_path)
if t:
vf_num = t.group(1)
break
except Exception:
pass
if vf_num is None:
else:
raise exception.PciDeviceNotFoundById(id=pci_addr)
return vf_num

View File

@ -14,10 +14,13 @@
# License for the specific language governing permissions and limitations
# under the License.
import typing as ty
from oslo_serialization import jsonutils
from nova import exception
from nova.i18n import _
from nova import objects
from nova.pci import devspec
@ -30,7 +33,7 @@ class Whitelist(object):
assignable.
"""
def __init__(self, whitelist_spec=None):
def __init__(self, whitelist_spec: str = None) -> None:
"""White list constructor
For example, the following json string specifies that devices whose
@ -50,7 +53,9 @@ class Whitelist(object):
self.specs = []
@staticmethod
def _parse_white_list_from_config(whitelists):
def _parse_white_list_from_config(
whitelists: str,
) -> ty.List[devspec.PciDeviceSpec]:
"""Parse and validate the pci whitelist from the nova config."""
specs = []
for jsonspec in whitelists:
@ -77,7 +82,7 @@ class Whitelist(object):
return specs
def device_assignable(self, dev):
def device_assignable(self, dev: ty.Dict[str, str]) -> bool:
"""Check if a device can be assigned to a guest.
:param dev: A dictionary describing the device properties
@ -87,7 +92,11 @@ class Whitelist(object):
return True
return False
def get_devspec(self, pci_dev):
def get_devspec(
self, pci_dev: 'objects.PciDevice',
) -> ty.Optional[devspec.PciDeviceSpec]:
for spec in self.specs:
if spec.match_pci_obj(pci_dev):
return spec
return None

View File

@ -448,10 +448,11 @@ class PciDevTrackerTestCase(test.NoDBTestCase):
self.inst.numa_topology = objects.InstanceNUMATopology(
cells=[objects.InstanceNUMACell(
id=1, cpuset=set([1, 2]), memory=512)])
self.assertIsNone(self.tracker.claim_instance(
mock.sentinel.context,
pci_requests_obj,
self.inst.numa_topology))
claims = self.tracker.claim_instance(
mock.sentinel.context,
pci_requests_obj,
self.inst.numa_topology)
self.assertEqual([], claims)
def test_update_pci_for_instance_deleted(self):
pci_requests_obj = self._create_pci_requests_object(fake_pci_requests)

View File

@ -251,14 +251,3 @@ class GetVfNumByPciAddressTestCase(test.NoDBTestCase):
utils.get_vf_num_by_pci_address,
self.pci_address
)
@mock.patch.object(os, 'readlink')
@mock.patch.object(glob, 'iglob')
def test_exception(self, mock_iglob, mock_readlink):
mock_iglob.return_value = self.paths
mock_readlink.side_effect = OSError('No such file or directory')
self.assertRaises(
exception.PciDeviceNotFoundById,
utils.get_vf_num_by_pci_address,
self.pci_address
)