From 6a71981e47d3c92fa52010df783ef0c79c577139 Mon Sep 17 00:00:00 2001
From: Stephen Finucane <stephenfin@redhat.com>
Date: Mon, 16 Mar 2020 13:00:15 +0000
Subject: [PATCH] libvirt: Add typing information

As with the 'nova.virt.hardware' module, add typing information here now
so that we can use it during development of later features. This
requires some minor tweaks of code that mypy found confusing.

Part of blueprint use-pcpu-and-vcpu-in-one-instance

Change-Id: Icc7b3d250bb9dd3d162731959185d9e962727247
Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
---
 mypy-files.txt              |   1 +
 nova/virt/libvirt/driver.py | 119 +++++++++++++++++++-----------------
 2 files changed, 64 insertions(+), 56 deletions(-)

diff --git a/mypy-files.txt b/mypy-files.txt
index b0f478a30539..89ad67ca1294 100644
--- a/mypy-files.txt
+++ b/mypy-files.txt
@@ -1,2 +1,3 @@
 nova/virt/hardware.py
 nova/virt/libvirt/__init__.py
+nova/virt/libvirt/driver.py
diff --git a/nova/virt/libvirt/driver.py b/nova/virt/libvirt/driver.py
index aef1aac6c78d..a3fb4387656d 100644
--- a/nova/virt/libvirt/driver.py
+++ b/nova/virt/libvirt/driver.py
@@ -41,6 +41,7 @@ import random
 import shutil
 import tempfile
 import time
+import typing as ty
 import uuid
 
 from castellan import key_manager
@@ -76,6 +77,7 @@ from nova.api.metadata import base as instance_metadata
 from nova.api.metadata import password
 from nova import block_device
 from nova.compute import power_state
+from nova.compute import provider_tree
 from nova.compute import task_states
 from nova.compute import utils as compute_utils
 from nova.compute import vm_states
@@ -127,7 +129,7 @@ from nova.virt.libvirt.volume import remotefs
 from nova.virt import netutils
 from nova.volume import cinder
 
-libvirt = None
+libvirt: ty.Any = None
 
 uefi_logged = False
 
@@ -389,7 +391,7 @@ class LibvirtDriver(driver.ComputeDriver):
         }
 
         self._sysinfo_serial_func = sysinfo_serial_funcs.get(
-            CONF.libvirt.sysinfo_serial)
+            CONF.libvirt.sysinfo_serial, lambda: None)
 
         self.job_tracker = instancejobtracker.InstanceJobTracker()
         self._remotefs = remotefs.RemoteFilesystem()
@@ -406,7 +408,7 @@ class LibvirtDriver(driver.ComputeDriver):
         # every time update_provider_tree() is called.
         # NOTE(sbauza): We only want a read-only cache, this attribute is not
         # intended to be updatable directly
-        self.provider_tree = None
+        self.provider_tree: provider_tree.ProviderTree = None
 
         # driver traits will not change during the runtime of the agent
         # so calcuate them once and save them
@@ -449,10 +451,12 @@ class LibvirtDriver(driver.ComputeDriver):
                     MIN_LIBVIRT_PMEM_SUPPORT)})
 
         # vpmem keyed by name {name: objects.LibvirtVPMEMDevice,...}
-        vpmems_by_name = {}
+        vpmems_by_name: ty.Dict[str, 'objects.LibvirtVPMEMDevice'] = {}
         # vpmem list keyed by resource class
         # {'RC_0': [objects.LibvirtVPMEMDevice, ...], 'RC_1': [...]}
-        vpmems_by_rc = collections.defaultdict(list)
+        vpmems_by_rc: ty.Dict[str, ty.List['objects.LibvirtVPMEMDevice']] = (
+            collections.defaultdict(list)
+        )
 
         vpmems_host = self._get_vpmems_on_host()
         for ns_conf in vpmem_conf:
@@ -1054,10 +1058,10 @@ class LibvirtDriver(driver.ComputeDriver):
     @staticmethod
     def _live_migration_uri(dest):
         uris = {
-            'kvm': 'qemu+%s://%s/system',
-            'qemu': 'qemu+%s://%s/system',
-            'xen': 'xenmigr://%s/system',
-            'parallels': 'parallels+tcp://%s/system',
+            'kvm': 'qemu+%(scheme)s://%(dest)s/system',
+            'qemu': 'qemu+%(scheme)s://%(dest)s/system',
+            'xen': 'xenmigr://%(dest)s/system',
+            'parallels': 'parallels+tcp://%(dest)s/system',
         }
         dest = oslo_netutils.escape_ipv6(dest)
 
@@ -1071,11 +1075,11 @@ class LibvirtDriver(driver.ComputeDriver):
         if uri is None:
             raise exception.LiveMigrationURINotAvailable(virt_type=virt_type)
 
-        str_format = (dest,)
-        if virt_type in ('kvm', 'qemu'):
-            scheme = CONF.libvirt.live_migration_scheme or 'tcp'
-            str_format = (scheme, dest)
-        return uris.get(virt_type) % str_format
+        str_format = {
+            'dest': dest,
+            'scheme': CONF.libvirt.live_migration_scheme or 'tcp',
+        }
+        return uri % str_format
 
     @staticmethod
     def _migrate_uri(dest):
@@ -1462,9 +1466,9 @@ class LibvirtDriver(driver.ComputeDriver):
     def _detach_encrypted_volumes(self, instance, block_device_info):
         """Detaches encrypted volumes attached to instance."""
         disks = self._get_instance_disk_info(instance, block_device_info)
-        encrypted_volumes = filter(dmcrypt.is_encrypted,
-                                   [disk['path'] for disk in disks])
-        for path in encrypted_volumes:
+        for path in [
+            d['path'] for d in disks if dmcrypt.is_encrypted(d['path'])
+        ]:
             dmcrypt.delete_volume(path)
 
     def _get_serial_ports_from_guest(self, guest, mode=None):
@@ -2977,17 +2981,12 @@ class LibvirtDriver(driver.ComputeDriver):
             raise exception.InstanceNotRunning(instance_id=instance.uuid)
 
         # Find dev name
-        my_dev = None
-        active_disk = None
-
         xml = guest.get_xml_desc()
         xml_doc = etree.fromstring(xml)
 
         device_info = vconfig.LibvirtConfigGuest()
         device_info.parse_dom(xml_doc)
 
-        active_disk_object = None
-
         for guest_disk in device_info.devices:
             if (guest_disk.root_name != 'disk'):
                 continue
@@ -2995,17 +2994,21 @@ class LibvirtDriver(driver.ComputeDriver):
             if (guest_disk.target_dev is None or guest_disk.serial is None):
                 continue
 
+            if (
+                guest_disk.source_path is None and
+                guest_disk.source_protocol is None
+            ):
+                continue
+
             if guest_disk.serial == volume_id:
                 my_dev = guest_disk.target_dev
 
-                active_disk = guest_disk.source_path
                 active_protocol = guest_disk.source_protocol
                 active_disk_object = guest_disk
                 break
-
-        if my_dev is None or (active_disk is None and active_protocol is None):
+        else:
             LOG.debug('Domain XML: %s', xml, instance=instance)
-            msg = (_('Disk with id: %s not found attached to instance.')
+            msg = (_("Disk with id '%s' not found attached to instance.")
                    % volume_id)
             raise exception.InternalError(msg)
 
@@ -3881,15 +3884,13 @@ class LibvirtDriver(driver.ComputeDriver):
             target_partition = None
 
         # Handles the key injection.
+        key = None
         if CONF.libvirt.inject_key and instance.get('key_data'):
             key = str(instance.key_data)
-        else:
-            key = None
 
         # Handles the admin password injection.
-        if not CONF.libvirt.inject_password:
-            admin_pass = None
-        else:
+        admin_pass = None
+        if CONF.libvirt.inject_password:
             admin_pass = injection_info.admin_pass
 
         # Handles the network injection.
@@ -4874,7 +4875,7 @@ class LibvirtDriver(driver.ComputeDriver):
         return idmaps
 
     def _get_guest_idmaps(self):
-        id_maps = []
+        id_maps: ty.List[vconfig.LibvirtConfigGuestIDMap] = []
         if CONF.libvirt.virt_type == 'lxc' and CONF.libvirt.uid_maps:
             uid_maps = self._create_idmaps(vconfig.LibvirtConfigGuestUIDMap,
                                            CONF.libvirt.uid_maps)
@@ -6500,7 +6501,7 @@ class LibvirtDriver(driver.ComputeDriver):
             events = []
 
         pause = bool(events)
-        guest = None
+        guest: libvirt_guest.Guest = None
         try:
             with self.virtapi.wait_for_instance_event(
                     instance, events, deadline=timeout,
@@ -6763,7 +6764,7 @@ class LibvirtDriver(driver.ComputeDriver):
         mdev device handles for that GPU
         """
 
-        counts_per_parent = collections.defaultdict(int)
+        counts_per_parent: ty.Dict[str, int] = collections.defaultdict(int)
         mediated_devices = self._get_mediated_devices(types=enabled_vgpu_types)
         for mdev in mediated_devices:
             parent_vgpu_type = self._get_vgpu_type_per_pgpu(mdev['parent'])
@@ -6785,7 +6786,7 @@ class LibvirtDriver(driver.ComputeDriver):
         """
         mdev_capable_devices = self._get_mdev_capable_devices(
             types=enabled_vgpu_types)
-        counts_per_dev = collections.defaultdict(int)
+        counts_per_dev: ty.Dict[str, int] = collections.defaultdict(int)
         for dev in mdev_capable_devices:
             # dev_id is the libvirt name for the PCI device,
             # eg. pci_0000_84_00_0 which matches a PCI address of 0000:84:00.0
@@ -7384,7 +7385,9 @@ class LibvirtDriver(driver.ComputeDriver):
             return cell.get(page_size, 0)
 
         def _get_physnet_numa_affinity():
-            affinities = {cell.id: set() for cell in topology.cells}
+            affinities: ty.Dict[int, ty.Set[str]] = {
+                cell.id: set() for cell in topology.cells
+            }
             for physnet in CONF.neutron.physnets:
                 # This will error out if the group is not registered, which is
                 # exactly what we want as that would be a bug
@@ -7429,11 +7432,12 @@ class LibvirtDriver(driver.ComputeDriver):
             cpuset = cpus & available_shared_cpus
             pcpuset = cpus & available_dedicated_cpus
 
-            siblings = sorted(map(set,
-                                  set(tuple(cpu.siblings)
-                                        if cpu.siblings else ()
-                                      for cpu in cell.cpus)
-                                  ))
+            # de-duplicate and sort the list of CPU sibling sets
+            siblings = sorted(
+                set(x) for x in set(
+                    tuple(cpu.siblings) or () for cpu in cell.cpus
+                )
+            )
 
             cpus &= available_shared_cpus | available_dedicated_cpus
             siblings = [sib & cpus for sib in siblings]
@@ -7593,7 +7597,9 @@ class LibvirtDriver(driver.ComputeDriver):
         # otherwise.
         inv = provider_tree.data(nodename).inventory
         ratios = self._get_allocation_ratios(inv)
-        resources = collections.defaultdict(set)
+        resources: ty.Dict[str, ty.Set['objects.Resource']] = (
+            collections.defaultdict(set)
+        )
         result = {
             orc.MEMORY_MB: {
                 'total': memory_mb,
@@ -7733,11 +7739,11 @@ class LibvirtDriver(driver.ComputeDriver):
             return db_const.MAX_INT
 
     @property
-    def static_traits(self):
+    def static_traits(self) -> ty.Dict[str, bool]:
         if self._static_traits is not None:
             return self._static_traits
 
-        traits = {}
+        traits: ty.Dict[str, bool] = {}
         traits.update(self._get_cpu_traits())
         traits.update(self._get_storage_bus_traits())
         traits.update(self._get_video_model_traits())
@@ -7890,7 +7896,7 @@ class LibvirtDriver(driver.ComputeDriver):
         :return: dict, keyed by PGPU device ID, to count of VGPUs on that
             device
         """
-        vgpu_count_per_pgpu = collections.defaultdict(int)
+        vgpu_count_per_pgpu: ty.Dict[str, int] = collections.defaultdict(int)
         for mdev_uuid in mdev_uuids:
             # libvirt name is like mdev_00ead764_fdc0_46b6_8db9_2963f5c815b4
             dev_name = libvirt_utils.mdev_uuid2name(mdev_uuid)
@@ -8393,7 +8399,7 @@ class LibvirtDriver(driver.ComputeDriver):
         return migrate_data
 
     def _get_resources(self, instance, prefix=None):
-        resources = []
+        resources: 'objects.ResourceList' = []
         if prefix:
             migr_context = instance.migration_context
             attr_name = prefix + 'resources'
@@ -9089,7 +9095,8 @@ class LibvirtDriver(driver.ComputeDriver):
                                 recover_method, block_migration,
                                 migrate_data, finish_event,
                                 disk_paths):
-        on_migration_failure = deque()
+
+        on_migration_failure: ty.Deque[str] = deque()
         data_gb = self._live_migration_data_gb(instance, disk_paths)
         downtime_steps = list(libvirt_migrate.downtime_steps(data_gb))
         migration = migrate_data.migration
@@ -10401,7 +10408,7 @@ class LibvirtDriver(driver.ComputeDriver):
     @staticmethod
     def _get_io_devices(xml_doc):
         """get the list of io devices from the xml document."""
-        result = {"volumes": [], "ifaces": []}
+        result: ty.Dict[str, ty.List[str]] = {"volumes": [], "ifaces": []}
         try:
             doc = etree.fromstring(xml_doc)
         except Exception:
@@ -10851,7 +10858,7 @@ class LibvirtDriver(driver.ComputeDriver):
                            nova.privsep.fs.FS_FORMAT_EXT4,
                            nova.privsep.fs.FS_FORMAT_XFS]
 
-    def _get_vif_model_traits(self):
+    def _get_vif_model_traits(self) -> ty.Dict[str, bool]:
         """Get vif model traits based on the currently enabled virt_type.
 
         Not all traits generated by this function may be valid and the result
@@ -10871,7 +10878,7 @@ class LibvirtDriver(driver.ComputeDriver):
             in supported_models for model in all_models
         }
 
-    def _get_storage_bus_traits(self):
+    def _get_storage_bus_traits(self) -> ty.Dict[str, bool]:
         """Get storage bus traits based on the currently enabled virt_type.
 
         For QEMU and KVM this function uses the information returned by the
@@ -10889,7 +10896,7 @@ class LibvirtDriver(driver.ComputeDriver):
 
         if CONF.libvirt.virt_type in ('qemu', 'kvm'):
             dom_caps = self._host.get_domain_capabilities()
-            supported_buses = set()
+            supported_buses: ty.Set[str] = set()
             for arch_type in dom_caps:
                 for machine_type in dom_caps[arch_type]:
                     supported_buses.update(
@@ -10906,7 +10913,7 @@ class LibvirtDriver(driver.ComputeDriver):
             supported_buses for bus in all_buses
         }
 
-    def _get_video_model_traits(self):
+    def _get_video_model_traits(self) -> ty.Dict[str, bool]:
         """Get video model traits from libvirt.
 
         Not all traits generated by this function may be valid and the result
@@ -10917,7 +10924,7 @@ class LibvirtDriver(driver.ComputeDriver):
         all_models = fields.VideoModel.ALL
 
         dom_caps = self._host.get_domain_capabilities()
-        supported_models = set()
+        supported_models: ty.Set[str] = set()
         for arch_type in dom_caps:
             for machine_type in dom_caps[arch_type]:
                 supported_models.update(
@@ -10930,7 +10937,7 @@ class LibvirtDriver(driver.ComputeDriver):
             in supported_models for model in all_models
         }
 
-    def _get_cpu_traits(self):
+    def _get_cpu_traits(self) -> ty.Dict[str, bool]:
         """Get CPU-related traits to be set and unset on the host's resource
         provider.
 
@@ -10942,7 +10949,7 @@ class LibvirtDriver(driver.ComputeDriver):
 
         return traits
 
-    def _get_cpu_feature_traits(self):
+    def _get_cpu_feature_traits(self) -> ty.Dict[str, bool]:
         """Get CPU traits of VMs based on guest CPU model config.
 
         1. If mode is 'host-model' or 'host-passthrough', use host's
@@ -10981,7 +10988,7 @@ class LibvirtDriver(driver.ComputeDriver):
                 feature_names = [f.name for f in cpu.features]
             return feature_names
 
-        features = set()
+        features: ty.Set[str] = set()
         # Choose a default CPU model when cpu_mode is not specified
         if cpu.mode is None:
             caps.host.cpu.model = libvirt_utils.get_cpu_model_from_arch(