typing: Add typing to osc_lib.cli

Change-Id: I56d2095a504694b5f2035fc81ed3836731e0e25f
Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
Stephen Finucane
2025-02-19 20:12:40 +00:00
parent 0fc2851044
commit ebbcb69f4a
7 changed files with 198 additions and 75 deletions

View File

@@ -14,7 +14,9 @@
"""OpenStackConfig subclass for argument compatibility""" """OpenStackConfig subclass for argument compatibility"""
import logging import logging
import typing as ty
from keystoneauth1.loading import identity as ksa_loading
from openstack.config import exceptions as sdk_exceptions from openstack.config import exceptions as sdk_exceptions
from openstack.config import loader as config from openstack.config import loader as config
from oslo_utils import strutils from oslo_utils import strutils
@@ -26,7 +28,9 @@ LOG = logging.getLogger(__name__)
# Sublcass OpenStackConfig in order to munge config values # Sublcass OpenStackConfig in order to munge config values
# before auth plugins are loaded # before auth plugins are loaded
class OSC_Config(config.OpenStackConfig): class OSC_Config(config.OpenStackConfig):
def _auth_select_default_plugin(self, config): def _auth_select_default_plugin(
self, config: dict[str, ty.Any]
) -> dict[str, ty.Any]:
"""Select a default plugin based on supplied arguments """Select a default plugin based on supplied arguments
Migrated from auth.select_auth_plugin() Migrated from auth.select_auth_plugin()
@@ -59,7 +63,9 @@ class OSC_Config(config.OpenStackConfig):
LOG.debug("Auth plugin {} selected".format(config['auth_type'])) LOG.debug("Auth plugin {} selected".format(config['auth_type']))
return config return config
def _auth_v2_arguments(self, config): def _auth_v2_arguments(
self, config: dict[str, ty.Any]
) -> dict[str, ty.Any]:
"""Set up v2-required arguments from v3 info """Set up v2-required arguments from v3 info
Migrated from auth.build_auth_params() Migrated from auth.build_auth_params()
@@ -72,18 +78,24 @@ class OSC_Config(config.OpenStackConfig):
config['auth']['tenant_name'] = config['auth']['project_name'] config['auth']['tenant_name'] = config['auth']['project_name']
return config return config
def _auth_v2_ignore_v3(self, config): def _auth_v2_ignore_v3(
self, config: dict[str, ty.Any]
) -> dict[str, ty.Any]:
"""Remove v3 arguments if present for v2 plugin """Remove v3 arguments if present for v2 plugin
Migrated from clientmanager.setup_auth() Migrated from clientmanager.setup_auth()
""" """
# NOTE(hieulq): If USER_DOMAIN_NAME, USER_DOMAIN_ID, PROJECT_DOMAIN_ID # NOTE(hieulq): If USER_DOMAIN_NAME, USER_DOMAIN_ID, PROJECT_DOMAIN_ID
# or PROJECT_DOMAIN_NAME is present and API_VERSION is 2.0, then # or PROJECT_DOMAIN_NAME is present and API_VERSION is 2.0, then
# ignore all domain related configs. # ignore all domain related configs.
if str(config.get('identity_api_version', '')).startswith( if not str(config.get('identity_api_version', '')).startswith('2'):
'2' return config
) and config.get('auth_type').endswith('password'):
if not config.get('auth_type') or not config['auth_type'].endswith(
'password'
):
return config
domain_props = [ domain_props = [
'project_domain_id', 'project_domain_id',
'project_domain_name', 'project_domain_name',
@@ -106,7 +118,9 @@ class OSC_Config(config.OpenStackConfig):
) )
return config return config
def _auth_default_domain(self, config): def _auth_default_domain(
self, config: dict[str, ty.Any]
) -> dict[str, ty.Any]:
"""Set a default domain from available arguments """Set a default domain from available arguments
Migrated from clientmanager.setup_auth() Migrated from clientmanager.setup_auth()
@@ -147,7 +161,7 @@ class OSC_Config(config.OpenStackConfig):
config['auth']['user_domain_id'] = default_domain config['auth']['user_domain_id'] = default_domain
return config return config
def auth_config_hook(self, config): def auth_config_hook(self, config: dict[str, ty.Any]) -> dict[str, ty.Any]:
"""Allow examination of config values before loading auth plugin """Allow examination of config values before loading auth plugin
OpenStackClient will override this to perform additional checks OpenStackClient will override this to perform additional checks
@@ -165,7 +179,11 @@ class OSC_Config(config.OpenStackConfig):
) )
return config return config
def _validate_auth(self, config, loader, fixed_argparse=None): def _validate_auth(
self,
config: dict[str, ty.Any],
loader: ksa_loading.BaseIdentityLoader[ty.Any],
) -> dict[str, ty.Any]:
"""Validate auth plugin arguments""" """Validate auth plugin arguments"""
# May throw a keystoneauth1.exceptions.NoMatchingPlugin # May throw a keystoneauth1.exceptions.NoMatchingPlugin
@@ -229,7 +247,8 @@ class OSC_Config(config.OpenStackConfig):
return config return config
def load_auth_plugin(self, config): # TODO(stephenfin): Add type once we have typing for SDK
def load_auth_plugin(self, config: dict[str, ty.Any]) -> ty.Any:
"""Get auth plugin and validate args""" """Get auth plugin and validate args"""
loader = self._get_auth_loader(config) loader = self._get_auth_loader(config)

View File

@@ -25,45 +25,45 @@ from osc_lib import utils
class DictColumn(columns.FormattableColumn[dict[str, ty.Any]]): class DictColumn(columns.FormattableColumn[dict[str, ty.Any]]):
"""Format column for dict content""" """Format column for dict content"""
def human_readable(self): def human_readable(self) -> str:
return utils.format_dict(self._value) return utils.format_dict(self._value)
def machine_readable(self): def machine_readable(self) -> dict[str, ty.Any]:
return dict(self._value or {}) return dict(self._value or {})
class DictListColumn(columns.FormattableColumn[dict[str, list[ty.Any]]]): class DictListColumn(columns.FormattableColumn[dict[str, list[ty.Any]]]):
"""Format column for dict, key is string, value is list""" """Format column for dict, key is string, value is list"""
def human_readable(self): def human_readable(self) -> str:
return utils.format_dict_of_list(self._value) return utils.format_dict_of_list(self._value) or ''
def machine_readable(self): def machine_readable(self) -> dict[str, list[ty.Any]]:
return dict(self._value or {}) return dict(self._value or {})
class ListColumn(columns.FormattableColumn[list[ty.Any]]): class ListColumn(columns.FormattableColumn[list[ty.Any]]):
"""Format column for list content""" """Format column for list content"""
def human_readable(self): def human_readable(self) -> str:
return utils.format_list(self._value) return utils.format_list(self._value) or ''
def machine_readable(self): def machine_readable(self) -> list[ty.Any]:
return [x for x in self._value or []] return [x for x in self._value or []]
class ListDictColumn(columns.FormattableColumn[list[dict[str, ty.Any]]]): class ListDictColumn(columns.FormattableColumn[list[dict[str, ty.Any]]]):
"""Format column for list of dict content""" """Format column for list of dict content"""
def human_readable(self): def human_readable(self) -> str:
return utils.format_list_of_dicts(self._value) return utils.format_list_of_dicts(self._value) or ''
def machine_readable(self): def machine_readable(self) -> list[dict[str, ty.Any]]:
return [dict(x) for x in self._value or []] return [dict(x) for x in self._value or []]
class SizeColumn(columns.FormattableColumn[ty.Union[int, float]]): class SizeColumn(columns.FormattableColumn[ty.Union[int, float]]):
"""Format column for file size content""" """Format column for file size content"""
def human_readable(self): def human_readable(self) -> str:
return utils.format_size(self._value) return utils.format_size(self._value)

View File

@@ -11,13 +11,19 @@
# under the License. # under the License.
# #
import argparse
import typing as ty
from openstack import connection
from openstack import exceptions from openstack import exceptions
from openstack.identity.v3 import project from openstack.identity.v3 import project
from osc_lib.i18n import _ from osc_lib.i18n import _
def add_project_owner_option_to_parser(parser): def add_project_owner_option_to_parser(
parser: argparse.ArgumentParser,
) -> None:
"""Register project and project domain options. """Register project and project domain options.
:param parser: argparse.Argument parser object. :param parser: argparse.Argument parser object.
@@ -38,7 +44,13 @@ def add_project_owner_option_to_parser(parser):
) )
def find_project(sdk_connection, name_or_id, domain_name_or_id=None): # TODO(stephenfin): This really doesn't belong here. This should be part of
# openstacksdk itself.
def find_project(
sdk_connection: connection.Connection,
name_or_id: str,
domain_name_or_id: ty.Optional[str] = None,
) -> project.Project:
"""Find a project by its name name or ID. """Find a project by its name name or ID.
If Forbidden to find the resource (a common case if the user does not have If Forbidden to find the resource (a common case if the user does not have
@@ -53,7 +65,6 @@ def find_project(sdk_connection, name_or_id, domain_name_or_id=None):
This can be used when there are multiple projects with a same name. This can be used when there are multiple projects with a same name.
:returns: the project object found :returns: the project object found
:rtype: `openstack.identity.v3.project.Project` :rtype: `openstack.identity.v3.project.Project`
""" """
try: try:
if domain_name_or_id: if domain_name_or_id:

View File

@@ -10,11 +10,15 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import argparse
from osc_lib.cli import parseractions from osc_lib.cli import parseractions
from osc_lib.i18n import _ from osc_lib.i18n import _
def add_marker_pagination_option_to_parser(parser): def add_marker_pagination_option_to_parser(
parser: argparse.ArgumentParser,
) -> None:
"""Add marker-based pagination options to the parser. """Add marker-based pagination options to the parser.
APIs that use marker-based paging use the marker and limit query parameters APIs that use marker-based paging use the marker and limit query parameters
@@ -45,7 +49,9 @@ def add_marker_pagination_option_to_parser(parser):
) )
def add_offset_pagination_option_to_parser(parser): def add_offset_pagination_option_to_parser(
parser: argparse.ArgumentParser,
) -> None:
"""Add offset-based pagination options to the parser. """Add offset-based pagination options to the parser.
APIs that use offset-based paging use the offset and limit query parameters APIs that use offset-based paging use the offset and limit query parameters

View File

@@ -16,9 +16,13 @@
"""argparse Custom Actions""" """argparse Custom Actions"""
import argparse import argparse
import collections.abc
import typing as ty
from osc_lib.i18n import _ from osc_lib.i18n import _
_T = ty.TypeVar('_T')
class KeyValueAction(argparse.Action): class KeyValueAction(argparse.Action):
"""A custom action to parse arguments as key=value pairs """A custom action to parse arguments as key=value pairs
@@ -26,7 +30,16 @@ class KeyValueAction(argparse.Action):
Ensures that ``dest`` is a dict and values are strings. Ensures that ``dest`` is a dict and values are strings.
""" """
def __call__(self, parser, namespace, values, option_string=None): def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: ty.Union[str, ty.Sequence[ty.Any], None],
option_string: ty.Optional[str] = None,
) -> None:
if not isinstance(values, str):
raise TypeError('expected str')
# Make sure we have an empty dict rather than None # Make sure we have an empty dict rather than None
if getattr(namespace, self.dest, None) is None: if getattr(namespace, self.dest, None) is None:
setattr(namespace, self.dest, {}) setattr(namespace, self.dest, {})
@@ -39,7 +52,7 @@ class KeyValueAction(argparse.Action):
msg = _("Property key must be specified: %s") msg = _("Property key must be specified: %s")
raise argparse.ArgumentError(self, msg % str(values)) raise argparse.ArgumentError(self, msg % str(values))
else: else:
getattr(namespace, self.dest, {}).update([values_list]) getattr(namespace, self.dest, {}).update(dict([values_list]))
else: else:
msg = _("Expected 'key=value' type, but got: %s") msg = _("Expected 'key=value' type, but got: %s")
raise argparse.ArgumentError(self, msg % str(values)) raise argparse.ArgumentError(self, msg % str(values))
@@ -51,7 +64,16 @@ class KeyValueAppendAction(argparse.Action):
Ensures that ``dest`` is a dict and values are lists of strings. Ensures that ``dest`` is a dict and values are lists of strings.
""" """
def __call__(self, parser, namespace, values, option_string=None): def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: ty.Union[str, ty.Sequence[ty.Any], None],
option_string: ty.Optional[str] = None,
) -> None:
if not isinstance(values, str):
raise TypeError('expected str')
# Make sure we have an empty dict rather than None # Make sure we have an empty dict rather than None
if getattr(namespace, self.dest, None) is None: if getattr(namespace, self.dest, None) is None:
setattr(namespace, self.dest, {}) setattr(namespace, self.dest, {})
@@ -86,13 +108,19 @@ class MultiKeyValueAction(argparse.Action):
def __init__( def __init__(
self, self,
option_strings, option_strings: ty.Sequence[str],
dest, dest: str,
nargs=None, nargs: ty.Union[int, str, None] = None,
required_keys=None, required_keys: ty.Optional[ty.Sequence[str]] = None,
optional_keys=None, optional_keys: ty.Optional[ty.Sequence[str]] = None,
**kwargs, const: ty.Optional[_T] = None,
): default: ty.Union[_T, str, None] = None,
type: ty.Optional[collections.abc.Callable[[str], _T]] = None,
choices: ty.Optional[collections.abc.Iterable[_T]] = None,
required: bool = False,
help: ty.Optional[str] = None,
metavar: ty.Union[str, tuple[str, ...], None] = None,
) -> None:
"""Initialize the action object, and parse customized options """Initialize the action object, and parse customized options
Required keys and optional keys can be specified when initializing Required keys and optional keys can be specified when initializing
@@ -106,12 +134,24 @@ class MultiKeyValueAction(argparse.Action):
msg = _("Parameter 'nargs' is not allowed, but got %s") msg = _("Parameter 'nargs' is not allowed, but got %s")
raise ValueError(msg % nargs) raise ValueError(msg % nargs)
super().__init__(option_strings, dest, **kwargs) super().__init__(
option_strings,
dest,
nargs=nargs,
const=const,
default=default,
type=type,
choices=choices,
required=required,
help=help,
metavar=metavar,
)
# required_keys: A list of keys that is required. None by default. # required_keys: A list of keys that is required. None by default.
if required_keys and not isinstance(required_keys, list): if required_keys and not isinstance(required_keys, list):
msg = _("'required_keys' must be a list") msg = _("'required_keys' must be a list")
raise TypeError(msg) raise TypeError(msg)
self.required_keys = set(required_keys or []) self.required_keys = set(required_keys or [])
# optional_keys: A list of keys that is optional. None by default. # optional_keys: A list of keys that is optional. None by default.
@@ -120,7 +160,7 @@ class MultiKeyValueAction(argparse.Action):
raise TypeError(msg) raise TypeError(msg)
self.optional_keys = set(optional_keys or []) self.optional_keys = set(optional_keys or [])
def validate_keys(self, keys): def validate_keys(self, keys: ty.Sequence[str]) -> None:
"""Validate the provided keys. """Validate the provided keys.
:param keys: A list of keys to validate. :param keys: A list of keys to validate.
@@ -159,7 +199,16 @@ class MultiKeyValueAction(argparse.Action):
}, },
) )
def __call__(self, parser, namespace, values, metavar=None): def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: ty.Union[str, ty.Sequence[ty.Any], None],
option_string: ty.Optional[str] = None,
) -> None:
if not isinstance(values, str):
raise TypeError('expected str')
# Make sure we have an empty list rather than None # Make sure we have an empty list rather than None
if getattr(namespace, self.dest, None) is None: if getattr(namespace, self.dest, None) is None:
setattr(namespace, self.dest, []) setattr(namespace, self.dest, [])
@@ -196,12 +245,22 @@ class MultiKeyValueCommaAction(MultiKeyValueAction):
Ex. key1=val1,val2,key2=val3 => {"key1": "val1,val2", "key2": "val3"} Ex. key1=val1,val2,key2=val3 => {"key1": "val1,val2", "key2": "val3"}
""" """
def __call__(self, parser, namespace, values, option_string=None): def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: ty.Union[str, ty.Sequence[ty.Any], None],
option_string: ty.Optional[str] = None,
) -> None:
"""Overwrite the __call__ function of MultiKeyValueAction """Overwrite the __call__ function of MultiKeyValueAction
This is done to handle scenarios where we may have comma seperated This is done to handle scenarios where we may have comma seperated
data as a single value. data as a single value.
""" """
if not isinstance(values, str):
msg = _("Invalid key=value pair, non-string value provided: %s")
raise argparse.ArgumentError(self, msg % str(values))
# Make sure we have an empty list rather than None # Make sure we have an empty list rather than None
if getattr(namespace, self.dest, None) is None: if getattr(namespace, self.dest, None) is None:
setattr(namespace, self.dest, []) setattr(namespace, self.dest, [])
@@ -245,7 +304,17 @@ class RangeAction(argparse.Action):
'6:9' sets ``dest`` to (6, 9) '6:9' sets ``dest`` to (6, 9)
""" """
def __call__(self, parser, namespace, values, option_string=None): def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: ty.Union[str, ty.Sequence[ty.Any], None],
option_string: ty.Optional[str] = None,
) -> None:
if not isinstance(values, str):
msg = _("Invalid range, non-string value provided")
raise argparse.ArgumentError(self, msg)
range = values.split(':') range = values.split(':')
if len(range) == 0: if len(range) == 0:
# Nothing passed, return a zero default # Nothing passed, return a zero default
@@ -279,7 +348,17 @@ class NonNegativeAction(argparse.Action):
Ensures the value is >= 0. Ensures the value is >= 0.
""" """
def __call__(self, parser, namespace, values, option_string=None): def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: ty.Union[str, ty.Sequence[ty.Any], None],
option_string: ty.Optional[str] = None,
) -> None:
if not isinstance(values, (str, int, float)):
msg = _("%s expected a non-negative integer")
raise argparse.ArgumentError(self, msg % str(option_string))
if int(values) >= 0: if int(values) >= 0:
setattr(namespace, self.dest, values) setattr(namespace, self.dest, values)
else: else:

View File

@@ -303,7 +303,9 @@ def find_resource(manager, name_or_id, **kwargs):
raise exceptions.CommandError(msg % name_or_id) raise exceptions.CommandError(msg % name_or_id)
def format_dict(data, prefix=None): def format_dict(
data: dict[str, ty.Any], prefix: ty.Optional[str] = None
) -> str:
"""Return a formatted string of key value pairs """Return a formatted string of key value pairs
:param data: a dict :param data: a dict
@@ -331,11 +333,13 @@ def format_dict(data, prefix=None):
return output[:-2] return output[:-2]
def format_dict_of_list(data, separator='; '): def format_dict_of_list(
data: ty.Optional[dict[str, list[ty.Any]]], separator: str = '; '
) -> ty.Optional[str]:
"""Return a formatted string of key value pair """Return a formatted string of key value pair
:param data: a dict, key is string, value is a list of string, for example: :param data: a dict, key is string, value is a list of string, for example:
{u'public': [u'2001:db8::8', u'172.24.4.6']} {'public': ['2001:db8::8', '172.24.4.6']}
:param separator: the separator to use between key/value pair :param separator: the separator to use between key/value pair
(default: '; ') (default: '; ')
:return: a string formatted to {'key1'=['value1', 'value2']} with separated :return: a string formatted to {'key1'=['value1', 'value2']} with separated
@@ -356,7 +360,9 @@ def format_dict_of_list(data, separator='; '):
return separator.join(output) return separator.join(output)
def format_list(data, separator=', '): def format_list(
data: ty.Optional[list[ty.Any]], separator: str = ', '
) -> ty.Optional[str]:
"""Return a formatted strings """Return a formatted strings
:param data: a list of strings :param data: a list of strings
@@ -369,7 +375,9 @@ def format_list(data, separator=', '):
return separator.join(sorted(data)) return separator.join(sorted(data))
def format_list_of_dicts(data): def format_list_of_dicts(
data: ty.Optional[list[dict[str, ty.Any]]],
) -> ty.Optional[str]:
"""Return a formatted string of key value pairs for each dict """Return a formatted string of key value pairs for each dict
:param data: a list of dicts :param data: a list of dicts
@@ -381,10 +389,10 @@ def format_list_of_dicts(data):
return '\n'.join(format_dict(i) for i in data) return '\n'.join(format_dict(i) for i in data)
def format_size(size): def format_size(size: ty.Union[int, float, None]) -> str:
"""Display size of a resource in a human readable format """Display size of a resource in a human readable format
:param string size: :param size:
The size of the resource in bytes. The size of the resource in bytes.
:returns: :returns:
@@ -399,13 +407,12 @@ def format_size(size):
base = 1000.0 base = 1000.0
index = 0 index = 0
if size is None: size_ = float(size) if size is not None else 0.0
size = 0 while size_ >= base:
while size >= base:
index = index + 1 index = index + 1
size = size / base size_ = size_ / base
padded = f'{size:.1f}' padded = f'{size_:.1f}'
stripped = padded.rstrip('0').rstrip('.') stripped = padded.rstrip('0').rstrip('.')
return f'{stripped}{suffix[index]}' return f'{stripped}{suffix[index]}'

View File

@@ -35,6 +35,7 @@ ignore_errors = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = [ module = [
"osc_lib.api.*", "osc_lib.api.*",
"osc_lib.cli.*",
"osc_lib.exceptions", "osc_lib.exceptions",
] ]
disallow_untyped_calls = true disallow_untyped_calls = true