diff --git a/mypy-files.txt b/mypy-files.txt index fad22f4c76ff..898eee25c7cd 100644 --- a/mypy-files.txt +++ b/mypy-files.txt @@ -1,6 +1,7 @@ nova/compute/manager.py nova/crypto.py nova/network/neutron.py +nova/pci nova/privsep/path.py nova/scheduler/client/report.py nova/scheduler/request_filter.py diff --git a/nova/pci/devspec.py b/nova/pci/devspec.py index d76a25ce6a37..52934252693d 100644 --- a/nova/pci/devspec.py +++ b/nova/pci/devspec.py @@ -14,9 +14,11 @@ import abc import re import string +import typing as ty from nova import exception from nova.i18n import _ +from nova import objects from nova.pci import utils MAX_VENDOR_ID = 0xFFFF @@ -29,24 +31,35 @@ ANY = '*' REGEX_ANY = '.*' +PCISpecAddressType = ty.Union[ty.Dict[str, str], str] + + class PciAddressSpec(metaclass=abc.ABCMeta): """Abstract class for all PCI address spec styles This class checks the address fields of the pci.passthrough_whitelist """ + def __init__(self, pci_addr: str) -> None: + self.domain = '' + self.bus = '' + self.slot = '' + self.func = '' + @abc.abstractmethod def match(self, pci_addr): pass - def is_single_address(self): + def is_single_address(self) -> bool: return all([ all(c in string.hexdigits for c in self.domain), all(c in string.hexdigits for c in self.bus), all(c in string.hexdigits for c in self.slot), all(c in string.hexdigits for c in self.func)]) - def _set_pci_dev_info(self, prop, maxval, hex_value): + def _set_pci_dev_info( + self, prop: str, maxval: int, hex_value: str + ) -> None: a = getattr(self, prop) if a == ANY: return @@ -70,8 +83,10 @@ class PhysicalPciAddress(PciAddressSpec): This function class will validate the address fields for a single PCI device. """ - def __init__(self, pci_addr): + + def __init__(self, pci_addr: PCISpecAddressType) -> None: try: + # TODO(stephenfin): Is this ever actually a string? if isinstance(pci_addr, dict): self.domain = pci_addr['domain'] self.bus = pci_addr['bus'] @@ -87,7 +102,7 @@ class PhysicalPciAddress(PciAddressSpec): except (KeyError, ValueError): raise exception.PciDeviceWrongAddressFormat(address=pci_addr) - def match(self, phys_pci_addr): + def match(self, phys_pci_addr: PciAddressSpec) -> bool: conditions = [ self.domain == phys_pci_addr.domain, self.bus == phys_pci_addr.bus, @@ -104,7 +119,7 @@ class PciAddressGlobSpec(PciAddressSpec): check for wildcards, and insert wildcards where the field is left blank. """ - def __init__(self, pci_addr): + def __init__(self, pci_addr: str) -> None: self.domain = ANY self.bus = ANY self.slot = ANY @@ -129,7 +144,7 @@ class PciAddressGlobSpec(PciAddressSpec): self._set_pci_dev_info('bus', MAX_BUS, '%02x') self._set_pci_dev_info('slot', MAX_SLOT, '%02x') - def match(self, phys_pci_addr): + def match(self, phys_pci_addr: PciAddressSpec) -> bool: conditions = [ self.domain in (ANY, phys_pci_addr.domain), self.bus in (ANY, phys_pci_addr.bus), @@ -146,7 +161,8 @@ class PciAddressRegexSpec(PciAddressSpec): The validation includes check for all PCI address attributes and validate their regex. """ - def __init__(self, pci_addr): + + def __init__(self, pci_addr: dict) -> None: try: self.domain = pci_addr.get('domain', REGEX_ANY) self.bus = pci_addr.get('bus', REGEX_ANY) @@ -159,7 +175,7 @@ class PciAddressRegexSpec(PciAddressSpec): except re.error: raise exception.PciDeviceWrongAddressFormat(address=pci_addr) - def match(self, phys_pci_addr): + def match(self, phys_pci_addr: PciAddressSpec) -> bool: conditions = [ bool(self.domain_regex.match(phys_pci_addr.domain)), bool(self.bus_regex.match(phys_pci_addr.bus)), @@ -187,11 +203,13 @@ class WhitelistPciAddress(object): | passthrough_whitelist = {"vendor_id":"1137","product_id":"0071"} """ - def __init__(self, pci_addr, is_physical_function): + def __init__( + self, pci_addr: PCISpecAddressType, is_physical_function: bool + ) -> None: self.is_physical_function = is_physical_function self._init_address_fields(pci_addr) - def _check_physical_function(self): + def _check_physical_function(self) -> None: if self.pci_address_spec.is_single_address(): self.is_physical_function = ( utils.is_physical_function( @@ -200,7 +218,8 @@ class WhitelistPciAddress(object): self.pci_address_spec.slot, self.pci_address_spec.func)) - def _init_address_fields(self, pci_addr): + def _init_address_fields(self, pci_addr: PCISpecAddressType) -> None: + self.pci_address_spec: PciAddressSpec if not self.is_physical_function: if isinstance(pci_addr, str): self.pci_address_spec = PciAddressGlobSpec(pci_addr) @@ -212,10 +231,12 @@ class WhitelistPciAddress(object): else: self.pci_address_spec = PhysicalPciAddress(pci_addr) - def match(self, pci_addr, pci_phys_addr): - """Match a device to this PciAddress. Assume this is called given - pci_addr and pci_phys_addr reported by libvirt, no attempt is made to - verify if pci_addr is a VF of pci_phys_addr. + def match(self, pci_addr: str, pci_phys_addr: ty.Optional[str]) -> bool: + """Match a device to this PciAddress. + + Assume this is called with a ``pci_addr`` and ``pci_phys_addr`` + reported by libvirt. No attempt is made to verify if ``pci_addr`` is a + VF of ``pci_phys_addr``. :param pci_addr: PCI address of the device to match. :param pci_phys_addr: PCI address of the parent of the device to match @@ -237,51 +258,57 @@ class WhitelistPciAddress(object): class PciDeviceSpec(PciAddressSpec): - def __init__(self, dev_spec): + def __init__(self, dev_spec: ty.Dict[str, str]) -> None: self.tags = dev_spec self._init_dev_details() - def _init_dev_details(self): + def _init_dev_details(self) -> None: self.vendor_id = self.tags.pop("vendor_id", ANY) self.product_id = self.tags.pop("product_id", ANY) + self.dev_name = self.tags.pop("devname", None) + self.address: ty.Optional[WhitelistPciAddress] = None # Note(moshele): The address attribute can be a string or a dict. # For glob syntax or specific pci it is a string and for regex syntax # it is a dict. The WhitelistPciAddress class handles both types. - self.address = self.tags.pop("address", None) - self.dev_name = self.tags.pop("devname", None) + address = self.tags.pop("address", None) self.vendor_id = self.vendor_id.strip() self._set_pci_dev_info('vendor_id', MAX_VENDOR_ID, '%04x') self._set_pci_dev_info('product_id', MAX_PRODUCT_ID, '%04x') - if self.address and self.dev_name: + if address and self.dev_name: raise exception.PciDeviceInvalidDeviceName() - if not self.dev_name: - pci_address = self.address or "*:*:*.*" - self.address = WhitelistPciAddress(pci_address, False) - def match(self, dev_dict): + if not self.dev_name: + self.address = WhitelistPciAddress(address or '*:*:*.*', False) + + def match(self, dev_dict: ty.Dict[str, str]) -> bool: + address_obj: ty.Optional[WhitelistPciAddress] + if self.dev_name: - address_str, pf = utils.get_function_by_ifname( - self.dev_name) + address_str, pf = utils.get_function_by_ifname(self.dev_name) if not address_str: return False # Note(moshele): In this case we always passing a string # of the PF pci address address_obj = WhitelistPciAddress(address_str, pf) - elif self.address: + else: # use self.address address_obj = self.address + + if not address_obj: + return False + return all([ self.vendor_id in (ANY, dev_dict['vendor_id']), self.product_id in (ANY, dev_dict['product_id']), address_obj.match(dev_dict['address'], dev_dict.get('parent_addr'))]) - def match_pci_obj(self, pci_obj): + def match_pci_obj(self, pci_obj: 'objects.PciDevice') -> bool: return self.match({'vendor_id': pci_obj.vendor_id, 'product_id': pci_obj.product_id, 'address': pci_obj.address, 'parent_addr': pci_obj.parent_addr}) - def get_tags(self): + def get_tags(self) -> ty.Dict[str, str]: return self.tags diff --git a/nova/pci/manager.py b/nova/pci/manager.py index b124b994d845..fc6a8417246d 100644 --- a/nova/pci/manager.py +++ b/nova/pci/manager.py @@ -15,11 +15,13 @@ # under the License. import collections +import typing as ty from oslo_config import cfg from oslo_log import log as logging from oslo_serialization import jsonutils +from nova import context as ctx from nova import exception from nova import objects from nova.objects import fields @@ -29,6 +31,9 @@ from nova.pci import whitelist CONF = cfg.CONF LOG = logging.getLogger(__name__) +MappingType = ty.Dict[str, ty.List['objects.PciDevice']] +PCIInvType = ty.DefaultDict[str, ty.List['objects.PciDevice']] + class PciDevTracker(object): """Manage pci devices in a compute node. @@ -51,17 +56,19 @@ class PciDevTracker(object): are saved. """ - def __init__(self, context, compute_node): + def __init__( + self, + context: ctx.RequestContext, + compute_node: 'objects.ComputeNode', + ): """Create a pci device tracker. :param context: The request context. :param compute_node: The object.ComputeNode whose PCI devices we're tracking. """ - - super(PciDevTracker, self).__init__() - self.stale = {} - self.node_id = compute_node.id + self.stale: ty.Dict[str, objects.PciDevice] = {} + self.node_id: str = compute_node.id self.dev_filter = whitelist.Whitelist(CONF.pci.passthrough_whitelist) numa_topology = compute_node.numa_topology if numa_topology: @@ -76,9 +83,10 @@ class PciDevTracker(object): self._build_device_tree(self.pci_devs) self._initial_instance_usage() - def _initial_instance_usage(self): - self.allocations = collections.defaultdict(list) - self.claims = collections.defaultdict(list) + def _initial_instance_usage(self) -> None: + self.allocations: PCIInvType = collections.defaultdict(list) + self.claims: PCIInvType = collections.defaultdict(list) + for dev in self.pci_devs: uuid = dev.instance_uuid if dev.status == fields.PciDeviceStatus.CLAIMED: @@ -88,7 +96,7 @@ class PciDevTracker(object): elif dev.status == fields.PciDeviceStatus.AVAILABLE: self.stats.add_device(dev) - def save(self, context): + def save(self, context: ctx.RequestContext) -> None: for dev in self.pci_devs: if dev.obj_what_changed(): with dev.obj_alternate_context(context): @@ -97,10 +105,12 @@ class PciDevTracker(object): self.pci_devs.objects.remove(dev) @property - def pci_stats(self): + def pci_stats(self) -> stats.PciDeviceStats: return self.stats - def update_devices_from_hypervisor_resources(self, devices_json): + def update_devices_from_hypervisor_resources( + self, devices_json: str, + ) -> None: """Sync the pci device tracker with hypervisor information. To support pci device hot plug, we sync with the hypervisor @@ -159,7 +169,7 @@ class PciDevTracker(object): self._set_hvdevs(devices) @staticmethod - def _build_device_tree(all_devs): + def _build_device_tree(all_devs: ty.List['objects.PciDevice']) -> None: """Build a tree of devices that represents parent-child relationships. We need to have the relationships set up so that we can easily make @@ -196,7 +206,7 @@ class PciDevTracker(object): if dev.parent_device: parents[dev.parent_addr].child_devices.append(dev) - def _set_hvdevs(self, devices): + def _set_hvdevs(self, devices: ty.List[ty.Dict[str, ty.Any]]) -> None: exist_addrs = set([dev.address for dev in self.pci_devs]) new_addrs = set([dev['address'] for dev in devices]) @@ -243,6 +253,7 @@ class PciDevTracker(object): self.stats.remove_device(existed) else: # Update tracked devices. + new_value: ty.Dict[str, ty.Any] new_value = next((dev for dev in devices if dev['address'] == existed.address)) new_value['compute_node_id'] = self.node_id @@ -276,7 +287,12 @@ class PciDevTracker(object): self._build_device_tree(self.pci_devs) - def _claim_instance(self, context, pci_requests, instance_numa_topology): + def _claim_instance( + self, + context: ctx.RequestContext, + pci_requests: 'objects.InstancePCIRequests', + instance_numa_topology: 'objects.InstanceNUMATopology', + ) -> ty.List['objects.PciDevice']: instance_cells = None if instance_numa_topology: instance_cells = instance_numa_topology.cells @@ -284,7 +300,7 @@ class PciDevTracker(object): devs = self.stats.consume_requests(pci_requests.requests, instance_cells) if not devs: - return None + return [] instance_uuid = pci_requests.instance_uuid for dev in devs: @@ -296,18 +312,15 @@ class PciDevTracker(object): {'instance': instance_uuid}) return devs - def _allocate_instance(self, instance, devs): - for dev in devs: - dev.allocate(instance) + def claim_instance( + self, + context: ctx.RequestContext, + pci_requests: 'objects.InstancePCIRequests', + instance_numa_topology: 'objects.InstanceNUMATopology', + ) -> ty.List['objects.PciDevice']: - def allocate_instance(self, instance): - devs = self.claims.pop(instance['uuid'], []) - self._allocate_instance(instance, devs) - if devs: - self.allocations[instance['uuid']] += devs - - def claim_instance(self, context, pci_requests, instance_numa_topology): devs = [] + if self.pci_devs and pci_requests.requests: instance_uuid = pci_requests.instance_uuid devs = self._claim_instance(context, pci_requests, @@ -316,7 +329,21 @@ class PciDevTracker(object): self.claims[instance_uuid] = devs return devs - def free_device(self, dev, instance): + def _allocate_instance( + self, instance: 'objects.Instance', devs: ty.List['objects.PciDevice'], + ) -> None: + for dev in devs: + dev.allocate(instance) + + def allocate_instance(self, instance: 'objects.Instance') -> None: + devs = self.claims.pop(instance['uuid'], []) + self._allocate_instance(instance, devs) + if devs: + self.allocations[instance['uuid']] += devs + + def free_device( + self, dev: 'objects.PciDevice', instance: 'objects.Instance' + ) -> None: """Free device from pci resource tracker :param dev: cloned pci device object that needs to be free @@ -335,7 +362,11 @@ class PciDevTracker(object): break def _remove_device_from_pci_mapping( - self, instance_uuid, pci_device, pci_mapping): + self, + instance_uuid: str, + pci_device: 'objects.PciDevice', + pci_mapping: MappingType, + ) -> None: """Remove a PCI device from allocations or claims. If there are no more PCI devices, pop the uuid. @@ -346,7 +377,9 @@ class PciDevTracker(object): if len(pci_devices) == 0: pci_mapping.pop(instance_uuid, None) - def _free_device(self, dev, instance=None): + def _free_device( + self, dev: 'objects.PciDevice', instance: 'objects.Instance' = None, + ) -> None: freed_devs = dev.free(instance) stale = self.stale.pop(dev.address, None) if stale: @@ -354,31 +387,41 @@ class PciDevTracker(object): for dev in freed_devs: self.stats.add_device(dev) - def free_instance_allocations(self, context, instance): + def free_instance_allocations( + self, context: ctx.RequestContext, instance: 'objects.Instance', + ) -> None: """Free devices that are in ALLOCATED state for instance. - :param context: user request context (nova.context.RequestContext) + :param context: user request context :param instance: instance object """ - if self.allocations.pop(instance['uuid'], None): - for dev in self.pci_devs: - if (dev.status == fields.PciDeviceStatus.ALLOCATED and - dev.instance_uuid == instance['uuid']): - self._free_device(dev) + if not self.allocations.pop(instance['uuid'], None): + return - def free_instance_claims(self, context, instance): + for dev in self.pci_devs: + if (dev.status == fields.PciDeviceStatus.ALLOCATED and + dev.instance_uuid == instance['uuid']): + self._free_device(dev) + + def free_instance_claims( + self, context: ctx.RequestContext, instance: 'objects.Instance', + ) -> None: """Free devices that are in CLAIMED state for instance. :param context: user request context (nova.context.RequestContext) :param instance: instance object """ - if self.claims.pop(instance['uuid'], None): - for dev in self.pci_devs: - if (dev.status == fields.PciDeviceStatus.CLAIMED and - dev.instance_uuid == instance['uuid']): - self._free_device(dev) + if not self.claims.pop(instance['uuid'], None): + return - def free_instance(self, context, instance): + for dev in self.pci_devs: + if (dev.status == fields.PciDeviceStatus.CLAIMED and + dev.instance_uuid == instance['uuid']): + self._free_device(dev) + + def free_instance( + self, context: ctx.RequestContext, instance: 'objects.Instance', + ) -> None: """Free devices that are in CLAIMED or ALLOCATED state for instance. :param context: user request context (nova.context.RequestContext) @@ -392,9 +435,13 @@ class PciDevTracker(object): self.free_instance_allocations(context, instance) self.free_instance_claims(context, instance) - def update_pci_for_instance(self, context, instance, sign): - """Update PCI usage information if devices are de/allocated. - """ + def update_pci_for_instance( + self, + context: ctx.RequestContext, + instance: 'objects.Instance', + sign: int, + ) -> None: + """Update PCI usage information if devices are de/allocated.""" if not self.pci_devs: return @@ -403,7 +450,11 @@ class PciDevTracker(object): if sign == 1: self.allocate_instance(instance) - def clean_usage(self, instances, migrations): + def clean_usage( + self, + instances: 'objects.InstanceList', + migrations: 'objects.MigrationList', + ) -> None: """Remove all usages for instances not passed in the parameter. The caller should hold the COMPUTE_RESOURCE_SEMAPHORE lock @@ -425,7 +476,9 @@ class PciDevTracker(object): self._free_device(dev) -def get_instance_pci_devs(inst, request_id=None): +def get_instance_pci_devs( + inst: 'objects.Instance', request_id: str = None, +) -> ty.List['objects.PciDevice']: """Get the devices allocated to one or all requests for an instance. - For generic PCI request, the request id is None. @@ -437,5 +490,8 @@ def get_instance_pci_devs(inst, request_id=None): pci_devices = inst.pci_devices if pci_devices is None: return [] - return [device for device in pci_devices if - device.request_id == request_id or request_id == 'all'] + + return [ + device for device in pci_devices if + device.request_id == request_id or request_id == 'all' + ] diff --git a/nova/pci/request.py b/nova/pci/request.py index 1924925c5530..01ea1ae11288 100644 --- a/nova/pci/request.py +++ b/nova/pci/request.py @@ -38,11 +38,14 @@ product_id is "0442" or "0443". """ +import typing as ty + import jsonschema from oslo_log import log as logging from oslo_serialization import jsonutils import nova.conf +from nova import context as ctx from nova import exception from nova.i18n import _ from nova.network import model as network_model @@ -50,7 +53,8 @@ from nova import objects from nova.objects import fields as obj_fields from nova.pci import utils -LOG = logging.getLogger(__name__) +Alias = ty.Dict[str, ty.Tuple[str, ty.List[ty.Dict[str, str]]]] + PCI_NET_TAG = 'physical_network' PCI_TRUSTED_TAG = 'trusted' PCI_DEVICE_TYPE_TAG = 'dev_type' @@ -61,6 +65,7 @@ DEVICE_TYPE_FOR_VNIC_TYPE = { } CONF = nova.conf.CONF +LOG = logging.getLogger(__name__) _ALIAS_SCHEMA = { "type": "object", @@ -104,18 +109,19 @@ _ALIAS_SCHEMA = { } -def _get_alias_from_config(): +def _get_alias_from_config() -> Alias: """Parse and validate PCI aliases from the nova config. :returns: A dictionary where the keys are device names and the values are - tuples of form ``(specs, numa_policy)``. ``specs`` is a list of PCI - device specs, while ``numa_policy`` describes the required NUMA - affinity of the device(s). + tuples of form ``(numa_policy, specs)``. ``numa_policy`` describes the + required NUMA affinity of the device(s), while ``specs`` is a list of + PCI device specs. :raises: exception.PciInvalidAlias if two aliases with the same name have different device types or different NUMA policies. """ jaliases = CONF.pci.alias - aliases = {} # map alias name to alias spec list + # map alias name to alias spec list + aliases: Alias = {} try: for jsonspecs in jaliases: spec = jsonutils.loads(jsonspecs) @@ -153,17 +159,18 @@ def _get_alias_from_config(): return aliases -def _translate_alias_to_requests(alias_spec, affinity_policy=None): +def _translate_alias_to_requests( + alias_spec: str, affinity_policy: str = None, +) -> ty.List['objects.InstancePCIRequest']: """Generate complete pci requests from pci aliases in extra_spec.""" pci_aliases = _get_alias_from_config() - pci_requests = [] + pci_requests: ty.List[objects.InstancePCIRequest] = [] for name, count in [spec.split(':') for spec in alias_spec.split(',')]: name = name.strip() if name not in pci_aliases: raise exception.PciRequestAliasNotDefined(alias=name) - count = int(count) numa_policy, spec = pci_aliases[name] policy = affinity_policy or numa_policy @@ -172,14 +179,18 @@ def _translate_alias_to_requests(alias_spec, affinity_policy=None): # handling for InstancePCIRequests created from the flavor. So it is # left empty. pci_requests.append(objects.InstancePCIRequest( - count=count, + count=int(count), spec=spec, alias_name=name, numa_policy=policy)) return pci_requests -def get_instance_pci_request_from_vif(context, instance, vif): +def get_instance_pci_request_from_vif( + context: ctx.RequestContext, + instance: 'objects.Instance', + vif: network_model.VIF, +) -> ty.Optional['objects.InstancePCIRequest']: """Given an Instance, return the PCI request associated to the PCI device related to the given VIF (if any) on the compute node the instance is currently running. @@ -233,7 +244,9 @@ def get_instance_pci_request_from_vif(context, instance, vif): node_id=cn_id) -def get_pci_requests_from_flavor(flavor, affinity_policy=None): +def get_pci_requests_from_flavor( + flavor: 'objects.Flavor', affinity_policy: str = None, +) -> 'objects.InstancePCIRequests': """Validate and return PCI requests. The ``pci_passthrough:alias`` extra spec describes the flavor's PCI @@ -279,7 +292,7 @@ def get_pci_requests_from_flavor(flavor, affinity_policy=None): :raises: exception.PciInvalidAlias if the configuration contains invalid aliases. """ - pci_requests = [] + pci_requests: ty.List[objects.InstancePCIRequest] = [] if ('extra_specs' in flavor and 'pci_passthrough:alias' in flavor['extra_specs']): pci_requests = _translate_alias_to_requests( diff --git a/nova/pci/stats.py b/nova/pci/stats.py index 56a3b02a0724..e8e810fa4f92 100644 --- a/nova/pci/stats.py +++ b/nova/pci/stats.py @@ -14,20 +14,29 @@ # License for the specific language governing permissions and limitations # under the License. +import copy +import typing as ty + from oslo_config import cfg from oslo_log import log as logging from nova import exception +from nova import objects from nova.objects import fields from nova.objects import pci_device_pool from nova.pci import utils from nova.pci import whitelist - CONF = cfg.CONF LOG = logging.getLogger(__name__) +# TODO(stephenfin): We might want to use TypedDict here. Refer to +# https://mypy.readthedocs.io/en/latest/kinds_of_types.html#typeddict for +# more information. +Pool = ty.Dict[str, ty.Any] + + class PciDeviceStats(object): """PCI devices summary information. @@ -54,32 +63,42 @@ class PciDeviceStats(object): pool_keys = ['product_id', 'vendor_id', 'numa_node', 'dev_type'] - def __init__(self, numa_topology, stats=None, dev_filter=None): - super(PciDeviceStats, self).__init__() + def __init__( + self, + numa_topology: 'objects.NUMATopology', + stats: 'objects.PCIDevicePoolList' = None, + dev_filter: whitelist.Whitelist = None, + ) -> None: self.numa_topology = numa_topology - # NOTE(sbauza): Stats are a PCIDevicePoolList object - self.pools = [pci_pool.to_dict() - for pci_pool in stats] if stats else [] + self.pools = ( + [pci_pool.to_dict() for pci_pool in stats] if stats else [] + ) self.pools.sort(key=lambda item: len(item)) self.dev_filter = dev_filter or whitelist.Whitelist( CONF.pci.passthrough_whitelist) - def _equal_properties(self, dev, entry, matching_keys): + def _equal_properties( + self, dev: Pool, entry: Pool, matching_keys: ty.List[str], + ) -> bool: return all(dev.get(prop) == entry.get(prop) for prop in matching_keys) - def _find_pool(self, dev_pool): + def _find_pool(self, dev_pool: Pool) -> ty.Optional[Pool]: """Return the first pool that matches dev.""" for pool in self.pools: pool_keys = pool.copy() del pool_keys['count'] del pool_keys['devices'] if (len(pool_keys.keys()) == len(dev_pool.keys()) and - self._equal_properties(dev_pool, pool_keys, dev_pool.keys())): + self._equal_properties(dev_pool, pool_keys, list(dev_pool))): return pool - def _create_pool_keys_from_dev(self, dev): - """create a stats pool dict that this dev is supposed to be part of + return None + + def _create_pool_keys_from_dev( + self, dev: 'objects.PciDevice', + ) -> ty.Optional[Pool]: + """Create a stats pool dict that this dev is supposed to be part of Note that this pool dict contains the stats pool's keys and their values. 'count' and 'devices' are not included. @@ -88,7 +107,7 @@ class PciDeviceStats(object): # This can happen during initial sync up with the controller devspec = self.dev_filter.get_devspec(dev) if not devspec: - return + return None tags = devspec.get_tags() pool = {k: getattr(dev, k) for k in self.pool_keys} if tags: @@ -103,7 +122,9 @@ class PciDeviceStats(object): pool['parent_ifname'] = dev.extra_info['parent_ifname'] return pool - def _get_pool_with_device_type_mismatch(self, dev): + def _get_pool_with_device_type_mismatch( + self, dev: 'objects.PciDevice', + ) -> ty.Optional[ty.Tuple[Pool, 'objects.PciDevice']]: """Check for device type mismatch in the pools for a given device. Return (pool, device) if device type does not match or a single None @@ -118,18 +139,18 @@ class PciDeviceStats(object): return None - def update_device(self, dev): + def update_device(self, dev: 'objects.PciDevice') -> None: """Update a device to its matching pool.""" pool_device_info = self._get_pool_with_device_type_mismatch(dev) if pool_device_info is None: - return + return None pool, device = pool_device_info pool['devices'].remove(device) self._decrease_pool_count(self.pools, pool) self.add_device(dev) - def add_device(self, dev): + def add_device(self, dev: 'objects.PciDevice') -> None: """Add a device to its matching pool.""" dev_pool = self._create_pool_keys_from_dev(dev) if dev_pool: @@ -144,7 +165,9 @@ class PciDeviceStats(object): pool['devices'].append(dev) @staticmethod - def _decrease_pool_count(pool_list, pool, count=1): + def _decrease_pool_count( + pool_list: ty.List[Pool], pool: Pool, count: int = 1, + ) -> int: """Decrement pool's size by count. If pool becomes empty, remove pool from pool_list. @@ -157,7 +180,7 @@ class PciDeviceStats(object): pool_list.remove(pool) return count - def remove_device(self, dev): + def remove_device(self, dev: 'objects.PciDevice') -> None: """Remove one device from the first pool that it matches.""" dev_pool = self._create_pool_keys_from_dev(dev) if dev_pool: @@ -168,14 +191,20 @@ class PciDeviceStats(object): pool['devices'].remove(dev) self._decrease_pool_count(self.pools, pool) - def get_free_devs(self): - free_devs = [] + def get_free_devs(self) -> ty.List['objects.PciDevice']: + free_devs: ty.List[objects.PciDevice] = [] for pool in self.pools: free_devs.extend(pool['devices']) return free_devs - def consume_requests(self, pci_requests, numa_cells=None): - alloc_devices = [] + def consume_requests( + self, + pci_requests: 'objects.InstancePCIRequests', + numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None, + ) -> ty.Optional[ty.List['objects.PciDevice']]: + + alloc_devices: ty.List[objects.PciDevice] = [] + for request in pci_requests: count = request.count @@ -212,7 +241,7 @@ class PciDeviceStats(object): return alloc_devices - def _handle_device_dependents(self, pci_dev): + def _handle_device_dependents(self, pci_dev: 'objects.PciDevice') -> None: """Remove device dependents or a parent from pools. In case the device is a PF, all of it's dependent VFs should @@ -238,7 +267,9 @@ class PciDeviceStats(object): except exception.PciDeviceNotFound: return - def _filter_pools_for_spec(self, pools, request): + def _filter_pools_for_spec( + self, pools: ty.List[Pool], request: 'objects.InstancePCIRequest', + ) -> ty.List[Pool]: """Filter out pools that don't match the request's device spec. Exclude pools that do not match the specified ``vendor_id``, @@ -257,7 +288,12 @@ class PciDeviceStats(object): if utils.pci_device_prop_match(pool, request_specs) ] - def _filter_pools_for_numa_cells(self, pools, request, numa_cells): + def _filter_pools_for_numa_cells( + self, + pools: ty.List[Pool], + request: 'objects.InstancePCIRequest', + numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']], + ) -> ty.List[Pool]: """Filter out pools with the wrong NUMA affinity, if required. Exclude pools that do not have *suitable* PCI NUMA affinity. @@ -335,7 +371,11 @@ class PciDeviceStats(object): return sorted( pools, key=lambda pool: pool.get('numa_node') not in numa_cell_ids) - def _filter_pools_for_socket_affinity(self, pools, numa_cells): + def _filter_pools_for_socket_affinity( + self, + pools: ty.List[Pool], + numa_cells: ty.List['objects.InstanceNUMACell'], + ) -> ty.List[Pool]: host_cells = self.numa_topology.cells # bail early if we don't have socket information for all host_cells. # This could happen if we're running on an weird older system with @@ -368,7 +408,9 @@ class PciDeviceStats(object): ) ] - def _filter_pools_for_unrequested_pfs(self, pools, request): + def _filter_pools_for_unrequested_pfs( + self, pools: ty.List[Pool], request: 'objects.InstancePCIRequest', + ) -> ty.List[Pool]: """Filter out pools with PFs, unless these are required. This is necessary in cases where PFs and VFs have the same product_id @@ -390,7 +432,11 @@ class PciDeviceStats(object): ] return pools - def _filter_pools_for_unrequested_vdpa_devices(self, pools, request): + def _filter_pools_for_unrequested_vdpa_devices( + self, + pools: ty.List[Pool], + request: 'objects.InstancePCIRequest', + ) -> ty.List[Pool]: """Filter out pools with VDPA devices, unless these are required. This is necessary as vdpa devices require special handling and @@ -412,7 +458,12 @@ class PciDeviceStats(object): ] return pools - def _filter_pools(self, pools, request, numa_cells): + def _filter_pools( + self, + pools: ty.List[Pool], + request: 'objects.InstancePCIRequest', + numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']], + ) -> ty.Optional[ty.List[Pool]]: """Determine if an individual PCI request can be met. Filter pools, which are collections of devices with similar traits, to @@ -502,7 +553,11 @@ class PciDeviceStats(object): return pools - def support_requests(self, requests, numa_cells=None): + def support_requests( + self, + requests: ty.List['objects.InstancePCIRequest'], + numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None, + ) -> bool: """Determine if the PCI requests can be met. Determine, based on a compute node's PCI stats, if an instance can be @@ -524,7 +579,12 @@ class PciDeviceStats(object): self._filter_pools(self.pools, r, numa_cells) for r in requests ) - def _apply_request(self, pools, request, numa_cells=None): + def _apply_request( + self, + pools: ty.List[Pool], + request: 'objects.InstancePCIRequest', + numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None, + ) -> bool: """Apply an individual PCI request. Apply a PCI request against a given set of PCI device pools, which are @@ -558,7 +618,11 @@ class PciDeviceStats(object): return True - def apply_requests(self, requests, numa_cells=None): + def apply_requests( + self, + requests: ty.List['objects.InstancePCIRequest'], + numa_cells: ty.Optional[ty.List['objects.InstanceNUMACell']] = None, + ) -> None: """Apply PCI requests to the PCI stats. This is used in multiple instance creation, when the scheduler has to @@ -580,22 +644,26 @@ class PciDeviceStats(object): ): raise exception.PciDeviceRequestFailed(requests=requests) - def __iter__(self): - # 'devices' shouldn't be part of stats - pools = [] + def __iter__(self) -> ty.Iterator[Pool]: + pools: ty.List[Pool] = [] for pool in self.pools: - tmp = {k: v for k, v in pool.items() if k != 'devices'} - pools.append(tmp) + pool = copy.deepcopy(pool) + # 'devices' shouldn't be part of stats + if 'devices' in pool: + del pool['devices'] + pools.append(pool) return iter(pools) - def clear(self): + def clear(self) -> None: """Clear all the stats maintained.""" self.pools = [] - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, PciDeviceStats): + return NotImplemented return self.pools == other.pools - def to_device_pools_obj(self): + def to_device_pools_obj(self) -> 'objects.PciDevicePoolList': """Return the contents of the pools as a PciDevicePoolList object.""" stats = [x for x in self] return pci_device_pool.from_pci_stats(stats) diff --git a/nova/pci/utils.py b/nova/pci/utils.py index ed0384b056a6..778cdb227c41 100644 --- a/nova/pci/utils.py +++ b/nova/pci/utils.py @@ -14,15 +14,19 @@ # License for the specific language governing permissions and limitations # under the License. - import glob import os import re +import typing as ty from oslo_log import log as logging from nova import exception +if ty.TYPE_CHECKING: + # avoid circular import + from nova.pci import stats + LOG = logging.getLogger(__name__) PCI_VENDOR_PATTERN = "^(hex{4})$".replace("hex", r"[\da-fA-F]") @@ -30,11 +34,12 @@ _PCI_ADDRESS_PATTERN = ("^(hex{4}):(hex{2}):(hex{2}).(oct{1})$". replace("hex", r"[\da-fA-F]"). replace("oct", "[0-7]")) _PCI_ADDRESS_REGEX = re.compile(_PCI_ADDRESS_PATTERN) - _SRIOV_TOTALVFS = "sriov_totalvfs" -def pci_device_prop_match(pci_dev, specs): +def pci_device_prop_match( + pci_dev: 'stats.Pool', specs: ty.List[ty.Dict[str, str]], +) -> bool: """Check if the pci_dev meet spec requirement Specs is a list of PCI device property requirements. @@ -47,7 +52,8 @@ def pci_device_prop_match(pci_dev, specs): "capabilities_network": ["rx", "tx", "tso", "gso"]}] """ - def _matching_devices(spec): + + def _matching_devices(spec: ty.Dict[str, str]) -> bool: for k, v in spec.items(): pci_dev_v = pci_dev.get(k) if isinstance(v, list) and isinstance(pci_dev_v, list): @@ -69,8 +75,10 @@ def pci_device_prop_match(pci_dev, specs): return any(_matching_devices(spec) for spec in specs) -def parse_address(address): - """Returns (domain, bus, slot, function) from PCI address that is stored in +def parse_address(address: str) -> ty.Sequence[str]: + """Parse a PCI address. + + Returns (domain, bus, slot, function) from PCI address that is stored in PciDevice DB table. """ m = _PCI_ADDRESS_REGEX.match(address) @@ -79,7 +87,7 @@ def parse_address(address): return m.groups() -def get_pci_address_fields(pci_addr): +def get_pci_address_fields(pci_addr: str) -> ty.Tuple[str, str, str, str]: """Parse a fully-specified PCI device address. Does not validate that the components are valid hex or wildcard values. @@ -92,7 +100,7 @@ def get_pci_address_fields(pci_addr): return domain, bus, slot, func -def get_pci_address(domain, bus, slot, func): +def get_pci_address(domain: str, bus: str, slot: str, func: str) -> str: """Assembles PCI address components into a fully-specified PCI address. Does not validate that the components are valid hex or wildcard values. @@ -103,7 +111,7 @@ def get_pci_address(domain, bus, slot, func): return '%s:%s:%s.%s' % (domain, bus, slot, func) -def get_function_by_ifname(ifname): +def get_function_by_ifname(ifname: str) -> ty.Tuple[ty.Optional[str], bool]: """Given the device name, returns the PCI address of a device and returns True if the address is in a physical function. """ @@ -121,7 +129,9 @@ def get_function_by_ifname(ifname): return None, False -def is_physical_function(domain, bus, slot, function): +def is_physical_function( + domain: str, bus: str, slot: str, function: str, +) -> bool: dev_path = "/sys/bus/pci/devices/%(d)s:%(b)s:%(s)s.%(f)s/" % { "d": domain, "b": bus, "s": slot, "f": function} if os.path.isdir(dev_path): @@ -134,7 +144,7 @@ def is_physical_function(domain, bus, slot, function): return False -def _get_sysfs_netdev_path(pci_addr, pf_interface): +def _get_sysfs_netdev_path(pci_addr: str, pf_interface: bool) -> str: """Get the sysfs path based on the PCI address of the device. Assumes a networking device - will not check for the existence of the path. @@ -144,7 +154,9 @@ def _get_sysfs_netdev_path(pci_addr, pf_interface): return "/sys/bus/pci/devices/%s/net" % pci_addr -def get_ifname_by_pci_address(pci_addr, pf_interface=False): +def get_ifname_by_pci_address( + pci_addr: str, pf_interface: bool = False, +) -> str: """Get the interface name based on a VF's pci address. The returned interface name is either the parent PF's or that of the VF @@ -158,7 +170,7 @@ def get_ifname_by_pci_address(pci_addr, pf_interface=False): raise exception.PciDeviceNotFoundById(id=pci_addr) -def get_mac_by_pci_address(pci_addr, pf_interface=False): +def get_mac_by_pci_address(pci_addr: str, pf_interface: bool = False) -> str: """Get the MAC address of the nic based on its PCI address. Raises PciDeviceNotFoundById in case the pci device is not a NIC @@ -179,7 +191,7 @@ def get_mac_by_pci_address(pci_addr, pf_interface=False): raise exception.PciDeviceNotFoundById(id=pci_addr) -def get_vf_num_by_pci_address(pci_addr): +def get_vf_num_by_pci_address(pci_addr: str) -> str: """Get the VF number based on a VF's pci address A VF is associated with an VF number, which ip link command uses to @@ -188,14 +200,14 @@ def get_vf_num_by_pci_address(pci_addr): VIRTFN_RE = re.compile(r"virtfn(\d+)") virtfns_path = "/sys/bus/pci/devices/%s/physfn/virtfn*" % (pci_addr) vf_num = None - try: - for vf_path in glob.iglob(virtfns_path): - if re.search(pci_addr, os.readlink(vf_path)): - t = VIRTFN_RE.search(vf_path) + + for vf_path in glob.iglob(virtfns_path): + if re.search(pci_addr, os.readlink(vf_path)): + t = VIRTFN_RE.search(vf_path) + if t: vf_num = t.group(1) break - except Exception: - pass - if vf_num is None: + else: raise exception.PciDeviceNotFoundById(id=pci_addr) + return vf_num diff --git a/nova/pci/whitelist.py b/nova/pci/whitelist.py index 3c9984094d00..7623f4903e50 100644 --- a/nova/pci/whitelist.py +++ b/nova/pci/whitelist.py @@ -14,10 +14,13 @@ # License for the specific language governing permissions and limitations # under the License. +import typing as ty + from oslo_serialization import jsonutils from nova import exception from nova.i18n import _ +from nova import objects from nova.pci import devspec @@ -30,7 +33,7 @@ class Whitelist(object): assignable. """ - def __init__(self, whitelist_spec=None): + def __init__(self, whitelist_spec: str = None) -> None: """White list constructor For example, the following json string specifies that devices whose @@ -50,7 +53,9 @@ class Whitelist(object): self.specs = [] @staticmethod - def _parse_white_list_from_config(whitelists): + def _parse_white_list_from_config( + whitelists: str, + ) -> ty.List[devspec.PciDeviceSpec]: """Parse and validate the pci whitelist from the nova config.""" specs = [] for jsonspec in whitelists: @@ -77,7 +82,7 @@ class Whitelist(object): return specs - def device_assignable(self, dev): + def device_assignable(self, dev: ty.Dict[str, str]) -> bool: """Check if a device can be assigned to a guest. :param dev: A dictionary describing the device properties @@ -87,7 +92,11 @@ class Whitelist(object): return True return False - def get_devspec(self, pci_dev): + def get_devspec( + self, pci_dev: 'objects.PciDevice', + ) -> ty.Optional[devspec.PciDeviceSpec]: for spec in self.specs: if spec.match_pci_obj(pci_dev): return spec + + return None diff --git a/nova/tests/unit/pci/test_manager.py b/nova/tests/unit/pci/test_manager.py index 3bc5ae65ea38..7cacf52dd2cb 100644 --- a/nova/tests/unit/pci/test_manager.py +++ b/nova/tests/unit/pci/test_manager.py @@ -448,10 +448,11 @@ class PciDevTrackerTestCase(test.NoDBTestCase): self.inst.numa_topology = objects.InstanceNUMATopology( cells=[objects.InstanceNUMACell( id=1, cpuset=set([1, 2]), memory=512)]) - self.assertIsNone(self.tracker.claim_instance( - mock.sentinel.context, - pci_requests_obj, - self.inst.numa_topology)) + claims = self.tracker.claim_instance( + mock.sentinel.context, + pci_requests_obj, + self.inst.numa_topology) + self.assertEqual([], claims) def test_update_pci_for_instance_deleted(self): pci_requests_obj = self._create_pci_requests_object(fake_pci_requests) diff --git a/nova/tests/unit/pci/test_utils.py b/nova/tests/unit/pci/test_utils.py index 500a36d13896..e444f137299e 100644 --- a/nova/tests/unit/pci/test_utils.py +++ b/nova/tests/unit/pci/test_utils.py @@ -251,14 +251,3 @@ class GetVfNumByPciAddressTestCase(test.NoDBTestCase): utils.get_vf_num_by_pci_address, self.pci_address ) - - @mock.patch.object(os, 'readlink') - @mock.patch.object(glob, 'iglob') - def test_exception(self, mock_iglob, mock_readlink): - mock_iglob.return_value = self.paths - mock_readlink.side_effect = OSError('No such file or directory') - self.assertRaises( - exception.PciDeviceNotFoundById, - utils.get_vf_num_by_pci_address, - self.pci_address - )