From 89533408209fbf23c90db7e036b65851c25dc058 Mon Sep 17 00:00:00 2001 From: Eric Harney Date: Wed, 3 Jun 2020 14:05:39 -0400 Subject: [PATCH] mypy: annotate volume_utils / utils / exc Change-Id: I886600b1712f4c9415e59cea7166289c0870e58c --- cinder/exception.py | 13 +- cinder/tests/unit/test_utils.py | 12 +- cinder/utils.py | 166 +++++++++++------- cinder/volume/volume_utils.py | 300 +++++++++++++++++++++----------- mypy-files.txt | 3 + 5 files changed, 318 insertions(+), 176 deletions(-) diff --git a/cinder/exception.py b/cinder/exception.py index 6c8f7f9391b..42505998084 100644 --- a/cinder/exception.py +++ b/cinder/exception.py @@ -22,6 +22,8 @@ SHOULD include dedicated exception logging. """ +from typing import Union + from oslo_log import log as logging from oslo_versionedobjects import exception as obj_exc import webob.exc @@ -35,7 +37,8 @@ LOG = logging.getLogger(__name__) class ConvertedException(webob.exc.WSGIHTTPException): - def __init__(self, code=500, title="", explanation=""): + def __init__(self, code: int = 500, title: str = "", + explanation: str = ""): self.code = code # There is a strict rule about constructing status line for HTTP: # '...Status-Line, consisting of the protocol version followed by a @@ -66,10 +69,10 @@ class CinderException(Exception): """ message = _("An unknown exception occurred.") code = 500 - headers = {} + headers: dict = {} safe = False - def __init__(self, message=None, **kwargs): + def __init__(self, message: Union[str, tuple] = None, **kwargs): self.kwargs = kwargs self.kwargs['message'] = message @@ -112,7 +115,7 @@ class CinderException(Exception): # with duplicate keyword exception. self.kwargs.pop('message', None) - def _log_exception(self): + def _log_exception(self) -> None: # kwargs doesn't match a variable in the message # log the issue and the kwargs LOG.exception('Exception in string format operation:') @@ -120,7 +123,7 @@ class CinderException(Exception): LOG.error("%(name)s: %(value)s", {'name': name, 'value': value}) - def _should_format(self): + def _should_format(self) -> bool: return self.kwargs['message'] is None or '%(message)' in self.message diff --git a/cinder/tests/unit/test_utils.py b/cinder/tests/unit/test_utils.py index dec973e1299..c35a970bc4e 100644 --- a/cinder/tests/unit/test_utils.py +++ b/cinder/tests/unit/test_utils.py @@ -297,12 +297,12 @@ class TemporaryChownTestCase(test.TestCase): mock_stat.return_value.st_uid = 5678 test_filename = 'a_file' with utils.temporary_chown(test_filename): - mock_exec.assert_called_once_with('chown', 1234, test_filename, + mock_exec.assert_called_once_with('chown', '1234', test_filename, run_as_root=True) mock_getuid.assert_called_once_with() mock_stat.assert_called_once_with(test_filename) - calls = [mock.call('chown', 1234, test_filename, run_as_root=True), - mock.call('chown', 5678, test_filename, run_as_root=True)] + calls = [mock.call('chown', '1234', test_filename, run_as_root=True), + mock.call('chown', '5678', test_filename, run_as_root=True)] mock_exec.assert_has_calls(calls) @mock.patch('os.stat') @@ -312,12 +312,12 @@ class TemporaryChownTestCase(test.TestCase): mock_stat.return_value.st_uid = 5678 test_filename = 'a_file' with utils.temporary_chown(test_filename, owner_uid=9101): - mock_exec.assert_called_once_with('chown', 9101, test_filename, + mock_exec.assert_called_once_with('chown', '9101', test_filename, run_as_root=True) self.assertFalse(mock_getuid.called) mock_stat.assert_called_once_with(test_filename) - calls = [mock.call('chown', 9101, test_filename, run_as_root=True), - mock.call('chown', 5678, test_filename, run_as_root=True)] + calls = [mock.call('chown', '9101', test_filename, run_as_root=True), + mock.call('chown', '5678', test_filename, run_as_root=True)] mock_exec.assert_has_calls(calls) @mock.patch('os.stat') diff --git a/cinder/utils.py b/cinder/utils.py index ecb8d4ae05c..014d7b99e88 100644 --- a/cinder/utils.py +++ b/cinder/utils.py @@ -22,6 +22,7 @@ import contextlib import datetime import functools import inspect +import logging as py_logging import math import multiprocessing import operator @@ -32,6 +33,9 @@ import shutil import stat import sys import tempfile +import typing +from typing import Callable, Dict, Iterable, Iterator, List # noqa: H301 +from typing import Optional, Tuple, Type, Union # noqa: H301 import eventlet from eventlet import tpool @@ -59,7 +63,7 @@ INFINITE_UNKNOWN_VALUES = ('infinite', 'unknown') synchronized = lockutils.synchronized_with_prefix('cinder-') -def as_int(obj, quiet=True): +def as_int(obj: Union[int, float, str], quiet: bool = True) -> int: # Try "2" -> 2 try: return int(obj) @@ -73,10 +77,12 @@ def as_int(obj, quiet=True): # Eck, not sure what this is then. if not quiet: raise TypeError(_("Can not translate %s to integer.") % (obj)) + + obj = typing.cast(int, obj) return obj -def check_exclusive_options(**kwargs): +def check_exclusive_options(**kwargs: dict) -> None: """Checks that only one of the provided options is actually not-none. Iterates over all the kwargs passed in and checks that only one of said @@ -99,24 +105,24 @@ def check_exclusive_options(**kwargs): # # Ex: 'the_key' -> 'the key' if pretty_keys: - names = [k.replace('_', ' ') for k in kwargs] + tnames = [k.replace('_', ' ') for k in kwargs] else: - names = kwargs.keys() - names = ", ".join(sorted(names)) + tnames = list(kwargs.keys()) + names = ", ".join(sorted(tnames)) msg = (_("May specify only one of %s") % (names)) raise exception.InvalidInput(reason=msg) -def execute(*cmd, **kwargs): +def execute(*cmd: str, **kwargs) -> Tuple[str, str]: """Convenience wrapper around oslo's execute() method.""" if 'run_as_root' in kwargs and 'root_helper' not in kwargs: kwargs['root_helper'] = get_root_helper() return processutils.execute(*cmd, **kwargs) -def check_ssh_injection(cmd_list): - ssh_injection_pattern = ['`', '$', '|', '||', ';', '&', '&&', '>', '>>', - '<'] +def check_ssh_injection(cmd_list: List[str]) -> None: + ssh_injection_pattern: Tuple[str, ...] = ('`', '$', '|', '||', ';', '&', + '&&', '>', '>>', '<') # Check whether injection attacks exist for arg in cmd_list: @@ -149,7 +155,8 @@ def check_ssh_injection(cmd_list): raise exception.SSHInjectionThreat(command=cmd_list) -def check_metadata_properties(metadata=None): +def check_metadata_properties( + metadata: Optional[Dict[str, str]]) -> None: """Checks that the volume metadata properties are valid.""" if not metadata: @@ -175,7 +182,9 @@ def check_metadata_properties(metadata=None): raise exception.InvalidVolumeMetadataSize(reason=msg) -def last_completed_audit_period(unit=None): +def last_completed_audit_period(unit: str = None) -> \ + Tuple[Union[datetime.datetime, datetime.timedelta], + Union[datetime.datetime, datetime.timedelta]]: """This method gives you the most recently *completed* audit period. arguments: @@ -196,11 +205,15 @@ def last_completed_audit_period(unit=None): if not unit: unit = CONF.volume_usage_audit_period - offset = 0 + unit = typing.cast(str, unit) + + offset: Union[str, int] = 0 if '@' in unit: unit, offset = unit.split("@", 1) offset = int(offset) + offset = typing.cast(int, offset) + rightnow = timeutils.utcnow() if unit not in ('month', 'day', 'year', 'hour'): raise ValueError('Time period must be hour, day, month or year') @@ -262,7 +275,7 @@ def last_completed_audit_period(unit=None): return (begin, end) -def monkey_patch(): +def monkey_patch() -> None: """Patches decorators for all functions in a specified module. If the CONF.monkey_patch set as True, @@ -309,7 +322,7 @@ def monkey_patch(): decorator("%s.%s" % (module, key), func)) -def make_dev_path(dev, partition=None, base='/dev'): +def make_dev_path(dev: str, partition: str = None, base: str = '/dev') -> str: """Return a path to a particular device. >>> make_dev_path('xvdc') @@ -324,7 +337,7 @@ def make_dev_path(dev, partition=None, base='/dev'): return path -def robust_file_write(directory, filename, data): +def robust_file_write(directory: str, filename: str, data: str) -> None: """Robust file write. Use "write to temp file and rename" model for writing the @@ -360,15 +373,16 @@ def robust_file_write(directory, filename, data): with excutils.save_and_reraise_exception(): LOG.error("Failed to write persistence file: %(path)s.", {'path': os.path.join(directory, filename)}) - if os.path.isfile(tempname): - os.unlink(tempname) + if tempname is not None: + if os.path.isfile(tempname): + os.unlink(tempname) finally: - if dirfd: + if dirfd is not None: os.close(dirfd) @contextlib.contextmanager -def temporary_chown(path, owner_uid=None): +def temporary_chown(path: str, owner_uid: int = None) -> Iterator[None]: """Temporarily chown a path. :params owner_uid: UID of temporary owner (defaults to current user) @@ -386,16 +400,16 @@ def temporary_chown(path, owner_uid=None): orig_uid = os.stat(path).st_uid if orig_uid != owner_uid: - execute('chown', owner_uid, path, run_as_root=True) + execute('chown', str(owner_uid), path, run_as_root=True) try: yield finally: if orig_uid != owner_uid: - execute('chown', orig_uid, path, run_as_root=True) + execute('chown', str(orig_uid), path, run_as_root=True) @contextlib.contextmanager -def tempdir(**kwargs): +def tempdir(**kwargs) -> Iterator[str]: tmpdir = tempfile.mkdtemp(**kwargs) try: yield tmpdir @@ -406,11 +420,11 @@ def tempdir(**kwargs): LOG.debug('Could not remove tmpdir: %s', str(e)) -def get_root_helper(): +def get_root_helper() -> str: return 'sudo cinder-rootwrap %s' % CONF.rootwrap_config -def require_driver_initialized(driver): +def require_driver_initialized(driver) -> None: """Verifies if `driver` is initialized If the driver is not initialized, an exception will be raised. @@ -427,7 +441,7 @@ def require_driver_initialized(driver): log_unsupported_driver_warning(driver) -def log_unsupported_driver_warning(driver): +def log_unsupported_driver_warning(driver) -> None: """Annoy the log about unsupported drivers.""" if not driver.supported: # Check to see if the driver is flagged as supported. @@ -440,22 +454,24 @@ def log_unsupported_driver_warning(driver): 'id': driver.__class__.__name__}) -def get_file_mode(path): +def get_file_mode(path: str) -> int: """This primarily exists to make unit testing easier.""" return stat.S_IMODE(os.stat(path).st_mode) -def get_file_gid(path): +def get_file_gid(path: str) -> int: """This primarily exists to make unit testing easier.""" return os.stat(path).st_gid -def get_file_size(path): +def get_file_size(path: str) -> int: """Returns the file size.""" return os.stat(path).st_size -def _get_disk_of_partition(devpath, st=None): +def _get_disk_of_partition( + devpath: str, + st: os.stat_result = None) -> Tuple[str, os.stat_result]: """Gets a disk device path and status from partition path. Returns a disk device path from a partition device path, and stat for @@ -478,7 +494,9 @@ def _get_disk_of_partition(devpath, st=None): return (devpath, st) -def get_bool_param(param_string, params, default=False): +def get_bool_param(param_string: str, + params: dict, + default: bool = False) -> bool: param = params.get(param_string, default) if not strutils.is_valid_boolstr(param): msg = _("Value '%(param)s' for '%(param_string)s' is not " @@ -488,7 +506,8 @@ def get_bool_param(param_string, params, default=False): return strutils.bool_from_string(param, strict=True) -def get_blkdev_major_minor(path, lookup_for_file=True): +def get_blkdev_major_minor(path: str, + lookup_for_file: bool = True) -> Optional[str]: """Get 'major:minor' number of block device. Get the device's 'major:minor' number of a block device to control @@ -516,8 +535,9 @@ def get_blkdev_major_minor(path, lookup_for_file=True): raise exception.CinderException(msg) -def check_string_length(value, name, min_length=0, max_length=None, - allow_all_spaces=True): +def check_string_length(value: str, name: str, min_length: int = 0, + max_length: int = None, + allow_all_spaces: bool = True) -> None: """Check the length of specified string. :param value: the value of the string @@ -537,7 +557,7 @@ def check_string_length(value, name, min_length=0, max_length=None, raise exception.InvalidInput(reason=msg) -def is_blk_device(dev): +def is_blk_device(dev: str) -> bool: try: if stat.S_ISBLK(os.stat(dev).st_mode): return True @@ -548,30 +568,30 @@ def is_blk_device(dev): class ComparableMixin(object): - def _compare(self, other, method): + def _compare(self, other: object, method: Callable): try: - return method(self._cmpkey(), other._cmpkey()) + return method(self._cmpkey(), other._cmpkey()) # type: ignore except (AttributeError, TypeError): # _cmpkey not implemented, or return different type, # so I can't compare with "other". return NotImplemented - def __lt__(self, other): + def __lt__(self, other: object): return self._compare(other, lambda s, o: s < o) - def __le__(self, other): + def __le__(self, other: object): return self._compare(other, lambda s, o: s <= o) - def __eq__(self, other): + def __eq__(self, other: object): return self._compare(other, lambda s, o: s == o) - def __ge__(self, other): + def __ge__(self, other: object): return self._compare(other, lambda s, o: s >= o) - def __gt__(self, other): + def __gt__(self, other: object): return self._compare(other, lambda s, o: s > o) - def __ne__(self, other): + def __ne__(self, other: object): return self._compare(other, lambda s, o: s != o) @@ -586,8 +606,12 @@ class retry_if_exit_code(tenacity.retry_if_exception): exc.exit_code in self.codes) -def retry(retry_param, interval=1, retries=3, backoff_rate=2, - wait_random=False, retry=tenacity.retry_if_exception_type): +def retry(retry_param: Optional[Type[Exception]], + interval: int = 1, + retries: int = 3, + backoff_rate: int = 2, + wait_random: bool = False, + retry=tenacity.retry_if_exception_type) -> Callable: if retries < 1: raise ValueError('Retries must be greater than or ' @@ -599,7 +623,7 @@ def retry(retry_param, interval=1, retries=3, backoff_rate=2, wait = tenacity.wait_exponential( multiplier=interval, min=0, exp_base=backoff_rate) - def _decorator(f): + def _decorator(f: Callable) -> Callable: @functools.wraps(f) def _wrapper(*args, **kwargs): @@ -618,7 +642,7 @@ def retry(retry_param, interval=1, retries=3, backoff_rate=2, return _decorator -def convert_str(text): +def convert_str(text: Union[str, bytes]) -> str: """Convert to native string. Convert bytes and Unicode strings to native strings: @@ -633,7 +657,8 @@ def convert_str(text): return text -def build_or_str(elements, str_format=None): +def build_or_str(elements: Union[None, str, Iterable[str]], + str_format: str = None) -> str: """Builds a string of elements joined by 'or'. Will join strings with the 'or' word and if a str_format is provided it @@ -651,18 +676,21 @@ def build_or_str(elements, str_format=None): if not isinstance(elements, str): elements = _(' or ').join(elements) + elements = typing.cast(str, elements) + if str_format: return str_format % elements + return elements -def calculate_virtual_free_capacity(total_capacity, - free_capacity, - provisioned_capacity, - thin_provisioning_support, - max_over_subscription_ratio, - reserved_percentage, - thin): +def calculate_virtual_free_capacity(total_capacity: float, + free_capacity: float, + provisioned_capacity: float, + thin_provisioning_support: bool, + max_over_subscription_ratio: float, + reserved_percentage: float, + thin: bool) -> float: """Calculate the virtual free capacity based on thin provisioning support. :param total_capacity: total_capacity_gb of a host_state or pool. @@ -693,8 +721,9 @@ def calculate_virtual_free_capacity(total_capacity, return free -def calculate_max_over_subscription_ratio(capability, - global_max_over_subscription_ratio): +def calculate_max_over_subscription_ratio( + capability: dict, + global_max_over_subscription_ratio: float) -> float: # provisioned_capacity_gb is the apparent total capacity of # all the volumes created on a backend, which is greater than # or equal to allocated_capacity_gb, which is the apparent @@ -752,7 +781,7 @@ def calculate_max_over_subscription_ratio(capability, return max_over_subscription_ratio -def validate_dictionary_string_length(specs): +def validate_dictionary_string_length(specs: dict) -> None: """Check the length of each key and value of dictionary.""" if not isinstance(specs, dict): msg = _('specs must be a dictionary.') @@ -768,7 +797,8 @@ def validate_dictionary_string_length(specs): min_length=0, max_length=255) -def service_expired_time(with_timezone=False): +def service_expired_time( + with_timezone: Optional[bool] = False) -> datetime.datetime: return (timeutils.utcnow(with_timezone=with_timezone) - datetime.timedelta(seconds=CONF.service_down_time)) @@ -794,7 +824,7 @@ def notifications_enabled(conf): return notifications_driver and notifications_driver != {'noop'} -def if_notifications_enabled(f): +def if_notifications_enabled(f: Callable) -> Callable: """Calls decorated method only if notifications are enabled.""" @functools.wraps(f) def wrapped(*args, **kwargs): @@ -807,7 +837,7 @@ def if_notifications_enabled(f): LOG_LEVELS = ('INFO', 'WARNING', 'ERROR', 'DEBUG') -def get_log_method(level_string): +def get_log_method(level_string: str) -> int: level_string = level_string or '' upper_level_string = level_string.upper() if upper_level_string not in LOG_LEVELS: @@ -816,7 +846,7 @@ def get_log_method(level_string): return getattr(logging, upper_level_string) -def set_log_levels(prefix, level_string): +def set_log_levels(prefix: str, level_string: str) -> None: level = get_log_method(level_string) prefix = prefix or '' @@ -825,18 +855,18 @@ def set_log_levels(prefix, level_string): v.logger.setLevel(level) -def get_log_levels(prefix): +def get_log_levels(prefix: str) -> dict: prefix = prefix or '' - return {k: logging.logging.getLevelName(v.logger.getEffectiveLevel()) + return {k: py_logging.getLevelName(v.logger.getEffectiveLevel()) for k, v in logging.get_loggers().items() if k and k.startswith(prefix)} -def paths_normcase_equal(path_a, path_b): +def paths_normcase_equal(path_a: str, path_b: str) -> bool: return os.path.normcase(path_a) == os.path.normcase(path_b) -def create_ordereddict(adict): +def create_ordereddict(adict: dict) -> OrderedDict: """Given a dict, return a sorted OrderedDict.""" return OrderedDict(sorted(adict.items(), key=operator.itemgetter(0))) @@ -859,7 +889,9 @@ class Semaphore(object): return self.semaphore.__exit__(*args) -def semaphore_factory(limit, concurrent_processes): +def semaphore_factory(limit: int, + concurrent_processes: int) -> Union[eventlet.Semaphore, + Semaphore]: """Get a semaphore to limit concurrent operations. The semaphore depends on the limit we want to set and the concurrent @@ -876,7 +908,7 @@ def semaphore_factory(limit, concurrent_processes): return contextlib.suppress() -def limit_operations(func): +def limit_operations(func: Callable) -> Callable: """Decorator to limit the number of concurrent operations. This method decorator expects to have a _semaphore attribute holding an diff --git a/cinder/volume/volume_utils.py b/cinder/volume/volume_utils.py index 23773f9dba5..618550bf5b2 100644 --- a/cinder/volume/volume_utils.py +++ b/cinder/volume/volume_utils.py @@ -30,6 +30,9 @@ import socket import tempfile import time import types +import typing +from typing import Any, BinaryIO, Callable, Dict, IO # noqa: H301 +from typing import List, Optional, Tuple, Union # noqa: H301 import uuid from castellan.common.credentials import keystone_password @@ -69,7 +72,7 @@ CONF = cfg.CONF LOG = logging.getLogger(__name__) -GB = units.Gi +GB: int = units.Gi # These attributes we will attempt to save for the volume if they exist # in the source image metadata. IMAGE_ATTRIBUTES = ( @@ -85,11 +88,13 @@ TRACE_API = False TRACE_METHOD = False -def null_safe_str(s): +def null_safe_str(s: Optional[str]) -> str: return str(s) if s else '' -def _usage_from_volume(context, volume_ref, **kw): +def _usage_from_volume(context: context.RequestContext, + volume_ref: 'objects.Volume', + **kw) -> dict: now = timeutils.utcnow() launched_at = volume_ref['launched_at'] or now created_at = volume_ref['created_at'] or now @@ -131,7 +136,7 @@ def _usage_from_volume(context, volume_ref, **kw): return usage_info -def _usage_from_backup(backup, **kw): +def _usage_from_backup(backup: 'objects.Backup', **kw) -> dict: num_dependent_backups = backup.num_dependent_backups usage_info = dict(tenant_id=backup.project_id, user_id=backup.user_id, @@ -156,8 +161,11 @@ def _usage_from_backup(backup, **kw): @utils.if_notifications_enabled -def notify_about_volume_usage(context, volume, event_suffix, - extra_usage_info=None, host=None): +def notify_about_volume_usage(context: context.RequestContext, + volume: 'objects.Volume', + event_suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -171,9 +179,11 @@ def notify_about_volume_usage(context, volume, event_suffix, @utils.if_notifications_enabled -def notify_about_backup_usage(context, backup, event_suffix, - extra_usage_info=None, - host=None): +def notify_about_backup_usage(context: context.RequestContext, + backup: 'objects.Backup', + event_suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -186,7 +196,9 @@ def notify_about_backup_usage(context, backup, event_suffix, usage_info) -def _usage_from_snapshot(snapshot, context, **extra_usage_info): +def _usage_from_snapshot(snapshot: 'objects.Snapshot', + context: context.RequestContext, + **extra_usage_info) -> dict: # (niedbalski) a snapshot might be related to a deleted # volume, if that's the case, the volume information is still # required for filling the usage_info, so we enforce to read @@ -212,8 +224,11 @@ def _usage_from_snapshot(snapshot, context, **extra_usage_info): @utils.if_notifications_enabled -def notify_about_snapshot_usage(context, snapshot, event_suffix, - extra_usage_info=None, host=None): +def notify_about_snapshot_usage(context: context.RequestContext, + snapshot: 'objects.Snapshot', + event_suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -227,7 +242,8 @@ def notify_about_snapshot_usage(context, snapshot, event_suffix, usage_info) -def _usage_from_capacity(capacity, **extra_usage_info): +def _usage_from_capacity(capacity: Dict[str, Any], + **extra_usage_info) -> Dict[str, Any]: capacity_info = { 'name_to_id': capacity['name_to_id'], @@ -244,8 +260,11 @@ def _usage_from_capacity(capacity, **extra_usage_info): @utils.if_notifications_enabled -def notify_about_capacity_usage(context, capacity, suffix, - extra_usage_info=None, host=None): +def notify_about_capacity_usage(context: context.RequestContext, + capacity: dict, + suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -260,8 +279,11 @@ def notify_about_capacity_usage(context, capacity, suffix, @utils.if_notifications_enabled -def notify_about_replication_usage(context, volume, suffix, - extra_usage_info=None, host=None): +def notify_about_replication_usage(context: context.RequestContext, + volume: 'objects.Volume', + suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -277,8 +299,11 @@ def notify_about_replication_usage(context, volume, suffix, @utils.if_notifications_enabled -def notify_about_replication_error(context, volume, suffix, - extra_error_info=None, host=None): +def notify_about_replication_error(context: context.RequestContext, + volume: 'objects.Volume', + suffix: str, + extra_error_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -293,7 +318,7 @@ def notify_about_replication_error(context, volume, suffix, usage_info) -def _usage_from_consistencygroup(group_ref, **kw): +def _usage_from_consistencygroup(group_ref: 'objects.Group', **kw) -> dict: usage_info = dict(tenant_id=group_ref.project_id, user_id=group_ref.user_id, availability_zone=group_ref.availability_zone, @@ -307,8 +332,11 @@ def _usage_from_consistencygroup(group_ref, **kw): @utils.if_notifications_enabled -def notify_about_consistencygroup_usage(context, group, event_suffix, - extra_usage_info=None, host=None): +def notify_about_consistencygroup_usage(context: context.RequestContext, + group: 'objects.Group', + event_suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -324,7 +352,7 @@ def notify_about_consistencygroup_usage(context, group, event_suffix, usage_info) -def _usage_from_group(group_ref, **kw): +def _usage_from_group(group_ref: 'objects.Group', **kw) -> dict: usage_info = dict(tenant_id=group_ref.project_id, user_id=group_ref.user_id, availability_zone=group_ref.availability_zone, @@ -339,8 +367,11 @@ def _usage_from_group(group_ref, **kw): @utils.if_notifications_enabled -def notify_about_group_usage(context, group, event_suffix, - extra_usage_info=None, host=None): +def notify_about_group_usage(context: context.RequestContext, + group: 'objects.Group', + event_suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -356,7 +387,7 @@ def notify_about_group_usage(context, group, event_suffix, usage_info) -def _usage_from_cgsnapshot(cgsnapshot, **kw): +def _usage_from_cgsnapshot(cgsnapshot: 'objects.CGSnapshot', **kw) -> dict: usage_info = dict( tenant_id=cgsnapshot.project_id, user_id=cgsnapshot.user_id, @@ -370,7 +401,8 @@ def _usage_from_cgsnapshot(cgsnapshot, **kw): return usage_info -def _usage_from_group_snapshot(group_snapshot, **kw): +def _usage_from_group_snapshot(group_snapshot: 'objects.GroupSnapshot', + **kw) -> dict: usage_info = dict( tenant_id=group_snapshot.project_id, user_id=group_snapshot.user_id, @@ -386,8 +418,11 @@ def _usage_from_group_snapshot(group_snapshot, **kw): @utils.if_notifications_enabled -def notify_about_cgsnapshot_usage(context, cgsnapshot, event_suffix, - extra_usage_info=None, host=None): +def notify_about_cgsnapshot_usage(context: context.RequestContext, + cgsnapshot: 'objects.CGSnapshot', + event_suffix: str, + extra_usage_info: dict = None, + host: str = None) -> None: if not host: host = CONF.host @@ -404,8 +439,11 @@ def notify_about_cgsnapshot_usage(context, cgsnapshot, event_suffix, @utils.if_notifications_enabled -def notify_about_group_snapshot_usage(context, group_snapshot, event_suffix, - extra_usage_info=None, host=None): +def notify_about_group_snapshot_usage(context: context.RequestContext, + group_snapshot: 'objects.GroupSnapshot', + event_suffix: str, + extra_usage_info=None, + host: str = None) -> None: if not host: host = CONF.host @@ -421,13 +459,14 @@ def notify_about_group_snapshot_usage(context, group_snapshot, event_suffix, usage_info) -def _check_blocksize(blocksize): +def _check_blocksize(blocksize: Union[str, int]) -> Union[str, int]: # Check if volume_dd_blocksize is valid try: # Rule out zero-sized/negative/float dd blocksize which # cannot be caught by strutils - if blocksize.startswith(('-', '0')) or '.' in blocksize: + if (blocksize.startswith(('-', '0')) or # type: ignore + '.' in blocksize): # type: ignore raise ValueError strutils.string_to_bytes('%sB' % blocksize) except ValueError: @@ -442,7 +481,8 @@ def _check_blocksize(blocksize): return blocksize -def check_for_odirect_support(src, dest, flag='oflag=direct'): +def check_for_odirect_support(src: str, dest: str, + flag: str = 'oflag=direct') -> bool: # Check whether O_DIRECT is supported try: @@ -459,9 +499,12 @@ def check_for_odirect_support(src, dest, flag='oflag=direct'): return False -def _copy_volume_with_path(prefix, srcstr, deststr, size_in_m, blocksize, - sync=False, execute=utils.execute, ionice=None, - sparse=False): +def _copy_volume_with_path(prefix, srcstr: str, deststr: str, + size_in_m: int, blocksize: Union[str, int], + sync: bool = False, + execute: Callable = utils.execute, + ionice=None, + sparse: bool = False) -> None: cmd = prefix[:] if ionice: @@ -514,16 +557,18 @@ def _copy_volume_with_path(prefix, srcstr, deststr, size_in_m, blocksize, {'size_in_m': size_in_m, 'mbps': mbps}) -def _open_volume_with_path(path, mode): +def _open_volume_with_path(path: str, mode: str) -> IO[Any]: try: with utils.temporary_chown(path): handle = open(path, mode) return handle except Exception: LOG.error("Failed to open volume from %(path)s.", {'path': path}) + raise -def _transfer_data(src, dest, length, chunk_size): +def _transfer_data(src: IO, dest: IO, + length: int, chunk_size: int) -> None: """Transfer data between files (Python IO objects).""" chunks = int(math.ceil(length / chunk_size)) @@ -554,15 +599,21 @@ def _transfer_data(src, dest, length, chunk_size): tpool.execute(dest.flush) -def _copy_volume_with_file(src, dest, size_in_m): +def _copy_volume_with_file(src: Union[str, IO], + dest: Union[str, IO], + size_in_m: int) -> None: src_handle = src if isinstance(src, str): src_handle = _open_volume_with_path(src, 'rb') + src_handle = typing.cast(IO, src_handle) + dest_handle = dest if isinstance(dest, str): dest_handle = _open_volume_with_path(dest, 'wb') + dest_handle = typing.cast(IO, dest_handle) + if not src_handle: raise exception.DeviceUnavailable( _("Failed to copy volume, source device unavailable.")) @@ -588,9 +639,12 @@ def _copy_volume_with_file(src, dest, size_in_m): {'size_in_m': size_in_m, 'mbps': mbps}) -def copy_volume(src, dest, size_in_m, blocksize, sync=False, +def copy_volume(src: Union[str, BinaryIO], + dest: Union[str, BinaryIO], + size_in_m: int, + blocksize: Union[str, int], sync=False, execute=utils.execute, ionice=None, throttle=None, - sparse=False): + sparse=False) -> None: """Copy data from the source volume to the destination volume. The parameters 'src' and 'dest' are both typically of type str, which @@ -617,9 +671,12 @@ def copy_volume(src, dest, size_in_m, blocksize, sync=False, _copy_volume_with_file(src, dest, size_in_m) -def clear_volume(volume_size, volume_path, volume_clear=None, - volume_clear_size=None, volume_clear_ionice=None, - throttle=None): +def clear_volume(volume_size: int, + volume_path: str, + volume_clear: str = None, + volume_clear_size: int = None, + volume_clear_ionice: str = None, + throttle=None) -> None: """Unprovision old volumes to prevent data leaking between users.""" if volume_clear is None: volume_clear = CONF.volume_clear @@ -649,24 +706,25 @@ def clear_volume(volume_size, volume_path, volume_clear=None, value=volume_clear) -def supports_thin_provisioning(): +def supports_thin_provisioning() -> bool: return brick_lvm.LVM.supports_thin_provisioning( utils.get_root_helper()) -def get_all_physical_volumes(vg_name=None): +def get_all_physical_volumes(vg_name=None) -> list: return brick_lvm.LVM.get_all_physical_volumes( utils.get_root_helper(), vg_name) -def get_all_volume_groups(vg_name=None): +def get_all_volume_groups(vg_name=None) -> list: return brick_lvm.LVM.get_all_volume_groups( utils.get_root_helper(), vg_name) -def extract_availability_zones_from_volume_type(volume_type): +def extract_availability_zones_from_volume_type(volume_type) \ + -> Optional[list]: if not volume_type: return None extra_specs = volume_type.get('extra_specs', {}) @@ -683,7 +741,9 @@ DEFAULT_PASSWORD_SYMBOLS = ('23456789', # Removed: 0,1 'abcdefghijkmnopqrstuvwxyz') # Removed: l -def generate_password(length=16, symbolgroups=DEFAULT_PASSWORD_SYMBOLS): +def generate_password( + length: int = 16, + symbolgroups: Tuple[str, ...] = DEFAULT_PASSWORD_SYMBOLS) -> str: """Generate a random password from the supplied symbol groups. At least one symbol from each group will be included. Unpredictable @@ -720,7 +780,9 @@ def generate_password(length=16, symbolgroups=DEFAULT_PASSWORD_SYMBOLS): return ''.join(password) -def generate_username(length=20, symbolgroups=DEFAULT_PASSWORD_SYMBOLS): +def generate_username( + length: int = 20, + symbolgroups: Tuple[str, ...] = DEFAULT_PASSWORD_SYMBOLS) -> str: # Use the same implementation as the password generation. return generate_password(length, symbolgroups) @@ -728,7 +790,9 @@ def generate_username(length=20, symbolgroups=DEFAULT_PASSWORD_SYMBOLS): DEFAULT_POOL_NAME = '_pool0' -def extract_host(host, level='backend', default_pool_name=False): +def extract_host(host: Optional[str], + level: str = 'backend', + default_pool_name: bool = False) -> Optional[str]: """Extract Host, Backend or Pool information from host string. :param host: String for host, which could include host@backend#pool info @@ -778,8 +842,11 @@ def extract_host(host, level='backend', default_pool_name=False): else: return None + return None # not hit -def append_host(host, pool): + +def append_host(host: Optional[str], + pool: Optional[str]) -> Optional[str]: """Encode pool into host info.""" if not host or not pool: return host @@ -788,7 +855,7 @@ def append_host(host, pool): return new_host -def matching_backend_name(src_volume_type, volume_type): +def matching_backend_name(src_volume_type, volume_type) -> bool: if src_volume_type.get('volume_backend_name') and \ volume_type.get('volume_backend_name'): return src_volume_type.get('volume_backend_name') == \ @@ -797,14 +864,14 @@ def matching_backend_name(src_volume_type, volume_type): return False -def hosts_are_equivalent(host_1, host_2): +def hosts_are_equivalent(host_1: str, host_2: str) -> bool: # In case host_1 or host_2 are None if not (host_1 and host_2): return host_1 == host_2 return extract_host(host_1) == extract_host(host_2) -def read_proc_mounts(): +def read_proc_mounts() -> List[str]: """Read the /proc/mounts file. It's a dummy function but it eases the writing of unit tests as mocking @@ -814,19 +881,20 @@ def read_proc_mounts(): return mounts.readlines() -def extract_id_from_volume_name(vol_name): - regex = re.compile( +def extract_id_from_volume_name(vol_name: str) -> Optional[str]: + regex: typing.Pattern = re.compile( CONF.volume_name_template.replace('%s', r'(?P.+)')) match = regex.match(vol_name) return match.group('uuid') if match else None -def check_already_managed_volume(vol_id): +def check_already_managed_volume(vol_id: Optional[str]): """Check cinder db for already managed volume. :param vol_id: volume id parameter :returns: bool -- return True, if db entry with specified volume id exists, otherwise return False + :raises: ValueError if vol_id is not a valid uuid string """ try: return (vol_id and isinstance(vol_id, str) and @@ -836,7 +904,7 @@ def check_already_managed_volume(vol_id): return False -def extract_id_from_snapshot_name(snap_name): +def extract_id_from_snapshot_name(snap_name: str) -> Optional[str]: """Return a snapshot's ID from its name on the backend.""" regex = re.compile( CONF.snapshot_name_template.replace('%s', r'(?P.+)')) @@ -844,8 +912,12 @@ def extract_id_from_snapshot_name(snap_name): return match.group('uuid') if match else None -def paginate_entries_list(entries, marker, limit, offset, sort_keys, - sort_dirs): +def paginate_entries_list(entries: List[Dict], + marker: Optional[Union[dict, str]], + limit: int, + offset: Optional[int], + sort_keys: List[str], + sort_dirs: List[str]) -> list: """Paginate a list of entries. :param entries: list of dictionaries @@ -859,7 +931,8 @@ def paginate_entries_list(entries, marker, limit, offset, sort_keys, comparers = [(operator.itemgetter(key.strip()), multiplier) for (key, multiplier) in zip(sort_keys, sort_dirs)] - def comparer(left, right): + def comparer(left, right) -> int: + fn: Callable for fn, d in comparers: left_val = fn(left) right_val = fn(right) @@ -900,7 +973,7 @@ def paginate_entries_list(entries, marker, limit, offset, sort_keys, return sorted_entries[start_index + offset:range_end + offset] -def convert_config_string_to_dict(config_string): +def convert_config_string_to_dict(config_string: str) -> dict: """Convert config file replication string to a dict. The only supported form is as follows: @@ -924,12 +997,16 @@ def convert_config_string_to_dict(config_string): return resultant_dict -def create_encryption_key(context, key_manager, volume_type_id): +def create_encryption_key(context: context.RequestContext, + key_manager, + volume_type_id: str) -> Optional[str]: encryption_key_id = None if volume_types.is_encrypted(context, volume_type_id): - volume_type_encryption = ( + volume_type_encryption: db.sqlalchemy.models.Encryption = ( volume_types.get_volume_type_encryption(context, volume_type_id)) + if volume_type_encryption is None: + raise exception.Invalid(message="Volume type error") cipher = volume_type_encryption.cipher length = volume_type_encryption.key_size algorithm = cipher.split('-')[0] if cipher else None @@ -945,10 +1022,13 @@ def create_encryption_key(context, key_manager, volume_type_id): LOG.exception("Key manager error") raise exception.Invalid(message="Key manager error") + typing.cast(str, encryption_key_id) return encryption_key_id -def delete_encryption_key(context, key_manager, encryption_key_id): +def delete_encryption_key(context: context.RequestContext, + key_manager, + encryption_key_id: str) -> None: try: key_manager.delete(context, encryption_key_id) except castellan_exception.ManagedObjectNotFoundError: @@ -972,7 +1052,9 @@ def delete_encryption_key(context, key_manager, encryption_key_id): pass -def clone_encryption_key(context, key_manager, encryption_key_id): +def clone_encryption_key(context: context.RequestContext, + key_manager, + encryption_key_id: str) -> str: clone_key_id = None if encryption_key_id is not None: clone_key_id = key_manager.store( @@ -981,19 +1063,19 @@ def clone_encryption_key(context, key_manager, encryption_key_id): return clone_key_id -def is_boolean_str(str): +def is_boolean_str(str: Optional[str]) -> bool: spec = (str or '').split() return (len(spec) == 2 and spec[0] == '' and strutils.bool_from_string(spec[1])) -def is_replicated_spec(extra_specs): - return (extra_specs and +def is_replicated_spec(extra_specs: dict) -> bool: + return (bool(extra_specs) and is_boolean_str(extra_specs.get('replication_enabled'))) -def is_multiattach_spec(extra_specs): - return (extra_specs and +def is_multiattach_spec(extra_specs: dict) -> bool: + return (bool(extra_specs) and is_boolean_str(extra_specs.get('multiattach'))) @@ -1003,7 +1085,7 @@ def group_get_by_id(group_id): return group -def is_group_a_cg_snapshot_type(group_or_snap): +def is_group_a_cg_snapshot_type(group_or_snap) -> bool: LOG.debug("Checking if %s is a consistent snapshot group", group_or_snap) if group_or_snap["group_type_id"] is not None: @@ -1015,7 +1097,7 @@ def is_group_a_cg_snapshot_type(group_or_snap): return False -def is_group_a_type(group, key): +def is_group_a_type(group: 'objects.Group', key: str) -> bool: if group.group_type_id is not None: spec = group_types.get_group_type_specs( group.group_type_id, key=key @@ -1024,7 +1106,9 @@ def is_group_a_type(group, key): return False -def get_max_over_subscription_ratio(str_value, supports_auto=False): +def get_max_over_subscription_ratio( + str_value: Union[str, float], + supports_auto: bool = False) -> Union[str, float]: """Get the max_over_subscription_ratio from a string As some drivers need to do some calculations with the value and we are now @@ -1044,6 +1128,7 @@ def get_max_over_subscription_ratio(str_value, supports_auto=False): raise exception.VolumeDriverException(message=msg) if str_value == 'auto': + str_value = typing.cast(str, str_value) return str_value mosr = float(str_value) @@ -1055,7 +1140,8 @@ def get_max_over_subscription_ratio(str_value, supports_auto=False): return mosr -def check_image_metadata(image_meta, vol_size): +def check_image_metadata(image_meta: Dict[str, Union[str, int]], + vol_size: int) -> None: """Validates the image metadata.""" # Check whether image is active if image_meta['status'] != 'active': @@ -1074,6 +1160,7 @@ def check_image_metadata(image_meta, vol_size): # Check image min_disk requirement is met for the particular volume min_disk = image_meta.get('min_disk', 0) + min_disk = typing.cast(int, min_disk) if vol_size < min_disk: msg = _('Volume size %(volume_size)sGB cannot be smaller' ' than the image minDisk size %(min_disk)sGB.') @@ -1081,7 +1168,7 @@ def check_image_metadata(image_meta, vol_size): raise exception.InvalidInput(reason=msg) -def enable_bootable_flag(volume): +def enable_bootable_flag(volume: 'objects.Volume') -> None: try: LOG.debug('Marking volume %s as bootable.', volume.id) volume.bootable = True @@ -1092,7 +1179,8 @@ def enable_bootable_flag(volume): raise exception.MetadataUpdateFailure(reason=ex) -def get_volume_image_metadata(image_id, image_meta): +def get_volume_image_metadata(image_id: str, + image_meta: Dict[str, Any]) -> dict: # Save some base attributes into the volume metadata base_metadata = { @@ -1114,6 +1202,7 @@ def get_volume_image_metadata(image_id, image_meta): # Save all the image metadata properties into the volume metadata property_metadata = {} image_properties = image_meta.get('properties', {}) + image_properties = typing.cast(dict, image_properties) for (key, value) in image_properties.items(): if value is not None: property_metadata[key] = value @@ -1123,8 +1212,12 @@ def get_volume_image_metadata(image_id, image_meta): return volume_metadata -def copy_image_to_volume(driver, context, volume, image_meta, image_location, - image_service): +def copy_image_to_volume(driver, + context: context.RequestContext, + volume: 'objects.Volume', + image_meta: dict, + image_location: str, + image_service) -> None: """Downloads Glance image to the specified volume.""" image_id = image_meta['id'] LOG.debug("Attempting download of %(image_id)s (%(image_location)s)" @@ -1173,7 +1266,7 @@ def copy_image_to_volume(driver, context, volume, image_meta, image_location, 'image_location': image_location}) -def image_conversion_dir(): +def image_conversion_dir() -> str: tmpdir = (CONF.image_conversion_dir or tempfile.gettempdir()) @@ -1184,7 +1277,9 @@ def image_conversion_dir(): return tmpdir -def check_encryption_provider(db, volume, context): +def check_encryption_provider(db, + volume: 'objects.Volume', + context: context.RequestContext) -> dict: """Check that this is a LUKS encryption provider. :returns: encryption dict @@ -1212,14 +1307,14 @@ def check_encryption_provider(db, volume, context): return encryption -def sanitize_host(host): +def sanitize_host(host: str) -> str: """Ensure IPv6 addresses are enclosed in [] for iSCSI portals.""" if netutils.is_valid_ipv6(host): return '[%s]' % host return host -def sanitize_hostname(hostname): +def sanitize_hostname(hostname) -> str: """Return a hostname which conforms to RFC-952 and RFC-1123 specs.""" hostname = hostname.encode('latin-1', 'ignore') hostname = hostname.decode('latin-1') @@ -1232,7 +1327,7 @@ def sanitize_hostname(hostname): return hostname -def resolve_hostname(hostname): +def resolve_hostname(hostname: str) -> str: """Resolves host name to IP address. Resolves a host name (my.data.point.com) to an IP address (10.12.143.11). @@ -1248,7 +1343,9 @@ def resolve_hostname(hostname): return ip -def update_backup_error(backup, err, status=fields.BackupStatus.ERROR): +def update_backup_error(backup, + err: str, + status=fields.BackupStatus.ERROR) -> None: backup.status = status backup.fail_reason = err backup.save() @@ -1256,7 +1353,7 @@ def update_backup_error(backup, err, status=fields.BackupStatus.ERROR): # TODO (whoami-rajat): Remove this method when oslo.vmware calls volume_utils # wrapper of upload_volume instead of image_utils.upload_volume -def get_base_image_ref(volume): +def get_base_image_ref(volume: 'objects.Volume'): # This method fetches the image_id from volume glance metadata and pass # it to the driver calling it during upload volume to image operation base_image_ref = None @@ -1265,9 +1362,12 @@ def get_base_image_ref(volume): return base_image_ref -def upload_volume(context, image_service, image_meta, volume_path, - volume, volume_format='raw', run_as_root=True, - compress=True): +def upload_volume(context: context.RequestContext, + image_service, image_meta, volume_path, + volume: 'objects.Volume', + volume_format: str = 'raw', + run_as_root: bool = True, + compress: bool = True) -> None: # retrieve store information from extra-specs store_id = volume.volume_type.extra_specs.get('image_service:store_id') @@ -1305,7 +1405,8 @@ def get_backend_configuration(backend_name, backend_opts=None): return config -def brick_get_connector_properties(multipath=False, enforce_multipath=False): +def brick_get_connector_properties(multipath: bool = False, + enforce_multipath: bool = False): """Wrapper to automatically set root_helper in brick calls. :param multipath: A boolean indicating whether the connector can @@ -1323,9 +1424,10 @@ def brick_get_connector_properties(multipath=False, enforce_multipath=False): enforce_multipath) -def brick_get_connector(protocol, driver=None, - use_multipath=False, - device_scan_attempts=3, +def brick_get_connector(protocol: str, + driver=None, + use_multipath: bool = False, + device_scan_attempts: int = 3, *args, **kwargs): """Wrapper to get a brick connector object. @@ -1342,7 +1444,7 @@ def brick_get_connector(protocol, driver=None, *args, **kwargs) -def brick_get_encryptor(connection_info, *args, **kwargs): +def brick_get_encryptor(connection_info: dict, *args, **kwargs): """Wrapper to get a brick encryptor object.""" root_helper = utils.get_root_helper() @@ -1353,7 +1455,9 @@ def brick_get_encryptor(connection_info, *args, **kwargs): *args, **kwargs) -def brick_attach_volume_encryptor(context, attach_info, encryption): +def brick_attach_volume_encryptor(context: context.RequestContext, + attach_info: dict, + encryption: dict) -> None: """Attach encryption layer.""" connection_info = attach_info['conn'] connection_info['data']['device_path'] = attach_info['device']['path'] @@ -1362,7 +1466,7 @@ def brick_attach_volume_encryptor(context, attach_info, encryption): encryptor.attach_volume(context, **encryption) -def brick_detach_volume_encryptor(attach_info, encryption): +def brick_detach_volume_encryptor(attach_info: dict, encryption: dict) -> None: """Detach encryption layer.""" connection_info = attach_info['conn'] connection_info['data']['device_path'] = attach_info['device']['path'] diff --git a/mypy-files.txt b/mypy-files.txt index 61c423ece8c..8ca96f3fa6e 100644 --- a/mypy-files.txt +++ b/mypy-files.txt @@ -1,6 +1,9 @@ cinder/context.py cinder/i18n.py +cinder/exception.py cinder/manager.py +cinder/utils.py cinder/volume/__init__.py cinder/volume/manager.py cinder/volume/volume_types.py +cinder/volume/volume_utils.py