Merge "Add typing"
This commit is contained in:
@@ -23,3 +23,18 @@ repos:
|
||||
- id: ruff-check
|
||||
args: ['--fix', '--unsafe-fixes']
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies: [
|
||||
'keystoneauth',
|
||||
'oslo.i18n',
|
||||
'openstacksdk',
|
||||
]
|
||||
# keep this in-sync with '[tool.mypy] exclude' in 'pyproject.toml'
|
||||
exclude: |
|
||||
(?x)(
|
||||
doc/.*
|
||||
| releasenotes/.*
|
||||
)
|
||||
|
||||
@@ -14,7 +14,11 @@ from oslo_limit._i18n import _
|
||||
|
||||
|
||||
class ProjectOverLimit(Exception):
|
||||
def __init__(self, project_id, over_limit_info_list):
|
||||
def __init__(
|
||||
self,
|
||||
project_id: str,
|
||||
over_limit_info_list: list['OverLimitInfo'],
|
||||
) -> None:
|
||||
"""Exception raised when a project goes over one or more limits
|
||||
|
||||
:param project_id: the project id
|
||||
@@ -22,8 +26,10 @@ class ProjectOverLimit(Exception):
|
||||
"""
|
||||
if not isinstance(over_limit_info_list, list):
|
||||
raise ValueError(over_limit_info_list)
|
||||
|
||||
if len(over_limit_info_list) == 0:
|
||||
raise ValueError(over_limit_info_list)
|
||||
|
||||
for info in over_limit_info_list:
|
||||
if not isinstance(info, OverLimitInfo):
|
||||
raise ValueError(over_limit_info_list)
|
||||
@@ -38,13 +44,19 @@ class ProjectOverLimit(Exception):
|
||||
|
||||
|
||||
class OverLimitInfo:
|
||||
def __init__(self, resource_name, limit, current_usage, delta):
|
||||
def __init__(
|
||||
self,
|
||||
resource_name: str,
|
||||
limit: int,
|
||||
current_usage: int,
|
||||
delta: int,
|
||||
):
|
||||
self.resource_name = resource_name
|
||||
self.limit = int(limit)
|
||||
self.current_usage = int(current_usage)
|
||||
self.delta = int(delta)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
template = (
|
||||
"Resource %s is over limit of %s due to "
|
||||
"current usage %s and delta %s"
|
||||
@@ -56,12 +68,12 @@ class OverLimitInfo:
|
||||
self.delta,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class SessionInitError(Exception):
|
||||
def __init__(self, reason):
|
||||
def __init__(self, reason: object) -> None:
|
||||
msg = _("Can't initialise OpenStackSDK session: %(reason)s.") % {
|
||||
'reason': reason
|
||||
}
|
||||
|
||||
@@ -15,13 +15,17 @@ from unittest import mock
|
||||
|
||||
import fixtures as fixtures
|
||||
|
||||
from openstack.identity.v3 import endpoint
|
||||
from openstack.identity.v3 import limit as keystone_limit
|
||||
from openstack.identity.v3 import registered_limit as keystone_rlimit
|
||||
from openstack.identity.v3 import endpoint as _endpoint
|
||||
from openstack.identity.v3 import limit as _limit
|
||||
from openstack.identity.v3 import registered_limit as _registered_limit
|
||||
|
||||
|
||||
class LimitFixture(fixtures.Fixture):
|
||||
def __init__(self, reglimits, projlimits):
|
||||
class LimitFixture(fixtures.Fixture): # type: ignore
|
||||
def __init__(
|
||||
self,
|
||||
reglimits: dict[str, int],
|
||||
projlimits: dict[str, dict[str, int]],
|
||||
) -> None:
|
||||
"""A fixture for testing code that relies on Keystone Unified Limits.
|
||||
|
||||
:param reglimits: A dictionary of {resource_name: limit} values to
|
||||
@@ -38,40 +42,48 @@ class LimitFixture(fixtures.Fixture):
|
||||
self.projlimits = projlimits
|
||||
|
||||
def get_reglimit_objects(
|
||||
self, service_id=None, region_id=None, resource_name=None
|
||||
):
|
||||
self,
|
||||
service_id: str | None = None,
|
||||
region_id: str | None = None,
|
||||
resource_name: str | None = None,
|
||||
) -> list[_registered_limit.RegisteredLimit]:
|
||||
limits = []
|
||||
for name, value in self.reglimits.items():
|
||||
if resource_name and resource_name != name:
|
||||
continue
|
||||
limit = keystone_rlimit.RegisteredLimit()
|
||||
|
||||
limit = _registered_limit.RegisteredLimit() # type: ignore
|
||||
limit.resource_name = name
|
||||
limit.default_limit = value
|
||||
limits.append(limit)
|
||||
|
||||
return limits
|
||||
|
||||
def get_projlimit_objects(
|
||||
self,
|
||||
service_id=None,
|
||||
region_id=None,
|
||||
resource_name=None,
|
||||
project_id=None,
|
||||
):
|
||||
service_id: str | None = None,
|
||||
region_id: str | None = None,
|
||||
resource_name: str | None = None,
|
||||
project_id: str | None = None,
|
||||
) -> list[_limit.Limit]:
|
||||
limits = []
|
||||
for proj_id, limit_dict in self.projlimits.items():
|
||||
if project_id and project_id != proj_id:
|
||||
continue
|
||||
|
||||
for name, value in limit_dict.items():
|
||||
if resource_name and resource_name != name:
|
||||
continue
|
||||
limit = keystone_limit.Limit()
|
||||
|
||||
limit = _limit.Limit() # type: ignore
|
||||
limit.project_id = proj_id
|
||||
limit.resource_name = name
|
||||
limit.resource_limit = value
|
||||
limits.append(limit)
|
||||
|
||||
return limits
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
# We mock our own cached connection to Keystone
|
||||
@@ -91,7 +103,7 @@ class LimitFixture(fixtures.Fixture):
|
||||
mock_gem.return_value = 'flat'
|
||||
|
||||
# Fake keystone endpoint; no per-service limit distinction
|
||||
fake_endpoint = endpoint.Endpoint()
|
||||
fake_endpoint = _endpoint.Endpoint() # type: ignore
|
||||
fake_endpoint.service_id = "service_id"
|
||||
fake_endpoint.region_id = "region_id"
|
||||
self.mock_conn.get_endpoint.return_value = fake_endpoint
|
||||
|
||||
@@ -10,13 +10,21 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections import defaultdict
|
||||
from collections import namedtuple
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
from typing import TypeAlias
|
||||
|
||||
from keystoneauth1 import exceptions as ksa_exceptions
|
||||
from keystoneauth1 import loading
|
||||
from openstack import connection
|
||||
from openstack import exceptions as os_exceptions
|
||||
from openstack.identity.v3 import _proxy as _identity_proxy
|
||||
from openstack.identity.v3 import endpoint as _endpoint
|
||||
from openstack.identity.v3 import limit as _limit
|
||||
from openstack.identity.v3 import registered_limit as _registered_limit
|
||||
from oslo_config import cfg
|
||||
from oslo_log import log
|
||||
|
||||
@@ -25,14 +33,38 @@ from oslo_limit import opts
|
||||
|
||||
CONF = cfg.CONF
|
||||
LOG = log.getLogger(__name__)
|
||||
_SDK_CONNECTION = None
|
||||
opts.register_opts(CONF)
|
||||
|
||||
_SDK_CONNECTION: _identity_proxy.Proxy | None = None
|
||||
|
||||
ProjectUsage = namedtuple('ProjectUsage', ['limit', 'usage'])
|
||||
|
||||
UsageCallbackT: TypeAlias = Callable[[str, list[str]], dict[str, int]]
|
||||
|
||||
def _get_keystone_connection():
|
||||
opts.register_opts(CONF)
|
||||
|
||||
|
||||
class _EnforcerImplProtocol(Protocol):
|
||||
name: str
|
||||
|
||||
def __init__(
|
||||
self, usage_callback: UsageCallbackT, cache: bool = True
|
||||
) -> None: ...
|
||||
|
||||
def get_registered_limits(
|
||||
self, resources_to_check: list[str]
|
||||
) -> list[tuple[str, int]]: ...
|
||||
|
||||
def get_project_limits(
|
||||
self, project_id: str, resource_names: list[str]
|
||||
) -> list[tuple[str, int]]: ...
|
||||
|
||||
def get_project_usage(
|
||||
self, project_id: str, resources_to_check: list[str]
|
||||
) -> dict[str, int]: ...
|
||||
|
||||
def enforce(self, project_id: str, deltas: dict[str, int]) -> None: ...
|
||||
|
||||
|
||||
def _get_keystone_connection() -> _identity_proxy.Proxy:
|
||||
global _SDK_CONNECTION
|
||||
if not _SDK_CONNECTION:
|
||||
try:
|
||||
@@ -72,7 +104,11 @@ def _get_keystone_connection():
|
||||
|
||||
|
||||
class Enforcer:
|
||||
def __init__(self, usage_callback, cache=True):
|
||||
model: _EnforcerImplProtocol
|
||||
|
||||
def __init__(
|
||||
self, usage_callback: UsageCallbackT, cache: bool = True
|
||||
) -> None:
|
||||
"""An object for checking usage against resource limits and requests.
|
||||
|
||||
:param usage_callback: A callable function that accepts a project_id
|
||||
@@ -90,11 +126,13 @@ class Enforcer:
|
||||
self.connection = _get_keystone_connection()
|
||||
self.model = self._get_model_impl(usage_callback, cache=cache)
|
||||
|
||||
def _get_enforcement_model(self):
|
||||
def _get_enforcement_model(self) -> str:
|
||||
"""Query keystone for the configured enforcement model."""
|
||||
return self.connection.get('/limits/model').json()['model']['name']
|
||||
return self.connection.get('/limits/model').json()['model']['name'] # type: ignore
|
||||
|
||||
def _get_model_impl(self, usage_callback, cache=True):
|
||||
def _get_model_impl(
|
||||
self, usage_callback: UsageCallbackT, cache: bool = True
|
||||
) -> _EnforcerImplProtocol:
|
||||
"""get the enforcement model based on configured model in keystone."""
|
||||
model = self._get_enforcement_model()
|
||||
for impl in _MODELS:
|
||||
@@ -102,7 +140,7 @@ class Enforcer:
|
||||
return impl(usage_callback, cache=cache)
|
||||
raise ValueError(f"enforcement model {model} is not supported")
|
||||
|
||||
def enforce(self, project_id, deltas):
|
||||
def enforce(self, project_id: str, deltas: dict[str, int]) -> None:
|
||||
"""Check resource usage against limits for resources in deltas
|
||||
|
||||
From the deltas we extract the list of resource types that need to
|
||||
@@ -156,7 +194,9 @@ class Enforcer:
|
||||
|
||||
self.model.enforce(project_id, deltas)
|
||||
|
||||
def calculate_usage(self, project_id, resources_to_check):
|
||||
def calculate_usage(
|
||||
self, project_id: str, resources_to_check: list[str]
|
||||
) -> dict[str, ProjectUsage]:
|
||||
"""Calculate resource usage and limits for resources_to_check.
|
||||
|
||||
From the list of resources_to_check, we collect the project's
|
||||
@@ -203,30 +243,42 @@ class Enforcer:
|
||||
for resource, limit in limits
|
||||
}
|
||||
|
||||
def get_registered_limits(self, resources_to_check):
|
||||
def get_registered_limits(
|
||||
self, resources_to_check: list[str]
|
||||
) -> list[tuple[str, int]]:
|
||||
return self.model.get_registered_limits(resources_to_check)
|
||||
|
||||
def get_project_limits(self, project_id, resources_to_check):
|
||||
def get_project_limits(
|
||||
self, project_id: str, resources_to_check: list[str]
|
||||
) -> list[tuple[str, int]]:
|
||||
return self.model.get_project_limits(project_id, resources_to_check)
|
||||
|
||||
|
||||
class _FlatEnforcer:
|
||||
name = 'flat'
|
||||
|
||||
def __init__(self, usage_callback, cache=True):
|
||||
def __init__(
|
||||
self, usage_callback: UsageCallbackT, cache: bool = True
|
||||
) -> None:
|
||||
self._usage_callback = usage_callback
|
||||
self._utils = _EnforcerUtils(cache=cache)
|
||||
|
||||
def get_registered_limits(self, resources_to_check):
|
||||
def get_registered_limits(
|
||||
self, resources_to_check: list[str]
|
||||
) -> list[tuple[str, int]]:
|
||||
return self._utils.get_registered_limits(resources_to_check)
|
||||
|
||||
def get_project_limits(self, project_id, resources_to_check):
|
||||
def get_project_limits(
|
||||
self, project_id: str, resources_to_check: list[str]
|
||||
) -> list[tuple[str, int]]:
|
||||
return self._utils.get_project_limits(project_id, resources_to_check)
|
||||
|
||||
def get_project_usage(self, project_id, resources_to_check):
|
||||
def get_project_usage(
|
||||
self, project_id: str, resources_to_check: list[str]
|
||||
) -> dict[str, int]:
|
||||
return self._usage_callback(project_id, resources_to_check)
|
||||
|
||||
def enforce(self, project_id, deltas):
|
||||
def enforce(self, project_id: str, deltas: dict[str, int]) -> None:
|
||||
resources_to_check = list(deltas.keys())
|
||||
# Always check the limits in the same order, for predictable errors
|
||||
resources_to_check.sort()
|
||||
@@ -244,27 +296,38 @@ class _FlatEnforcer:
|
||||
class _StrictTwoLevelEnforcer:
|
||||
name = 'strict-two-level'
|
||||
|
||||
def __init__(self, usage_callback, cache=True):
|
||||
def __init__(
|
||||
self, usage_callback: UsageCallbackT, cache: bool = True
|
||||
) -> None:
|
||||
self._usage_callback = usage_callback
|
||||
|
||||
def get_registered_limits(self, resources_to_check):
|
||||
def get_registered_limits(
|
||||
self, resources_to_check: list[str]
|
||||
) -> list[tuple[str, int]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_project_limits(self, project_id, resources_to_check):
|
||||
def get_project_limits(
|
||||
self, project_id: str, resources_to_check: list[str]
|
||||
) -> list[tuple[str, int]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_project_usage(self, project_id, resources_to_check):
|
||||
def get_project_usage(
|
||||
self, project_id: str, resources_to_check: list[str]
|
||||
) -> dict[str, int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def enforce(self, project_id, deltas):
|
||||
def enforce(self, project_id: str, deltas: dict[str, int]) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
_MODELS = [_FlatEnforcer, _StrictTwoLevelEnforcer]
|
||||
_MODELS: list[type[_EnforcerImplProtocol]] = [
|
||||
_FlatEnforcer,
|
||||
_StrictTwoLevelEnforcer,
|
||||
]
|
||||
|
||||
|
||||
class _LimitNotFound(Exception):
|
||||
def __init__(self, resource):
|
||||
def __init__(self, resource: str) -> None:
|
||||
msg = f"Can't find the limit for resource {resource}"
|
||||
self.resource = resource
|
||||
super().__init__(msg)
|
||||
@@ -273,37 +336,39 @@ class _LimitNotFound(Exception):
|
||||
class _EnforcerUtils:
|
||||
"""Logic common used by multiple enforcers"""
|
||||
|
||||
def __init__(self, cache=True):
|
||||
def __init__(self, cache: bool = True) -> None:
|
||||
self.connection = _get_keystone_connection()
|
||||
self.should_cache = cache
|
||||
# {project_id: {resource_name: project_limit}}
|
||||
self.plimit_cache = defaultdict(dict)
|
||||
self.plimit_cache: dict[str, dict[str, _limit.Limit]] = defaultdict(
|
||||
dict
|
||||
)
|
||||
# {resource_name: registered_limit}
|
||||
self.rlimit_cache = {}
|
||||
self.rlimit_cache: dict[str, _registered_limit.RegisteredLimit] = {}
|
||||
|
||||
self._endpoint = self._get_endpoint()
|
||||
self._service_id = self._endpoint.service_id
|
||||
self._region_id = self._endpoint.region_id
|
||||
self._endpoint: _endpoint.Endpoint = self._get_endpoint()
|
||||
self._service_id: str = self._endpoint.service_id
|
||||
self._region_id: str = self._endpoint.region_id
|
||||
|
||||
def _get_endpoint(self):
|
||||
def _get_endpoint(self) -> _endpoint.Endpoint:
|
||||
endpoint = self._get_endpoint_by_id()
|
||||
if endpoint is not None:
|
||||
return endpoint
|
||||
|
||||
return self._get_endpoint_by_service_lookup()
|
||||
|
||||
def _get_endpoint_by_id(self):
|
||||
def _get_endpoint_by_id(self) -> _endpoint.Endpoint | None:
|
||||
endpoint_id = CONF.oslo_limit.endpoint_id
|
||||
if endpoint_id is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
endpoint = self.connection.get_endpoint(endpoint_id)
|
||||
endpoint = self.connection.get_endpoint(endpoint_id) # type: ignore
|
||||
except os_exceptions.ResourceNotFound:
|
||||
raise ValueError(f"Can't find endpoint for {endpoint_id}")
|
||||
return endpoint
|
||||
return cast(_endpoint.Endpoint, endpoint)
|
||||
|
||||
def _get_endpoint_by_service_lookup(self):
|
||||
def _get_endpoint_by_service_lookup(self) -> _endpoint.Endpoint:
|
||||
service_type = CONF.oslo_limit.endpoint_service_type
|
||||
service_name = CONF.oslo_limit.endpoint_service_name
|
||||
if not service_type and not service_name:
|
||||
@@ -312,7 +377,7 @@ class _EnforcerUtils:
|
||||
)
|
||||
|
||||
try:
|
||||
services = self.connection.services(
|
||||
services = self.connection.services( # type: ignore
|
||||
type=service_type, name=service_name
|
||||
)
|
||||
services = list(services)
|
||||
@@ -324,7 +389,7 @@ class _EnforcerUtils:
|
||||
|
||||
if CONF.oslo_limit.endpoint_region_name is not None:
|
||||
try:
|
||||
regions = self.connection.regions(
|
||||
regions = self.connection.regions( # type: ignore
|
||||
name=CONF.oslo_limit.endpoint_region_name
|
||||
)
|
||||
regions = list(regions)
|
||||
@@ -340,7 +405,7 @@ class _EnforcerUtils:
|
||||
if interface.endswith('URL'):
|
||||
interface = interface[:-3]
|
||||
try:
|
||||
endpoints = self.connection.endpoints(
|
||||
endpoints = self.connection.endpoints( # type: ignore
|
||||
service_id=service_id,
|
||||
region_id=region_id,
|
||||
interface=interface,
|
||||
@@ -352,10 +417,15 @@ class _EnforcerUtils:
|
||||
if len(endpoints) > 1:
|
||||
raise ValueError("Multiple endpoints found")
|
||||
|
||||
return endpoints[0]
|
||||
return cast(_endpoint.Endpoint, endpoints[0])
|
||||
|
||||
@staticmethod
|
||||
def enforce_limits(project_id, limits, current_usage, deltas):
|
||||
def enforce_limits(
|
||||
project_id: str,
|
||||
limits: list[tuple[str, int]],
|
||||
current_usage: dict[str, int],
|
||||
deltas: dict[str, int],
|
||||
) -> None:
|
||||
"""Check that proposed usage is not over given limits
|
||||
|
||||
:param project_id: project being checked or None
|
||||
@@ -385,9 +455,9 @@ class _EnforcerUtils:
|
||||
LOG.debug("hit limit for project: %s", over_limit_list)
|
||||
raise exception.ProjectOverLimit(project_id, over_limit_list)
|
||||
|
||||
def _get_registered_limits(self):
|
||||
def _get_registered_limits(self) -> list[tuple[str, int]]:
|
||||
registered_limits = []
|
||||
reg_limits = self.connection.registered_limits(
|
||||
reg_limits = self.connection.registered_limits( # type: ignore
|
||||
service_id=self._service_id, region_id=self._region_id
|
||||
)
|
||||
for reg_limit in reg_limits:
|
||||
@@ -397,7 +467,9 @@ class _EnforcerUtils:
|
||||
self.rlimit_cache[name] = reg_limit
|
||||
return registered_limits
|
||||
|
||||
def get_registered_limits(self, resource_names):
|
||||
def get_registered_limits(
|
||||
self, resource_names: list[str] | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Get all the default limits for a given resource name list
|
||||
|
||||
:param resource_names: list of resource_name strings
|
||||
@@ -420,7 +492,7 @@ class _EnforcerUtils:
|
||||
|
||||
return registered_limits
|
||||
|
||||
def _get_project_limits(self, project_id):
|
||||
def _get_project_limits(self, project_id: str) -> list[tuple[str, int]]:
|
||||
if project_id is None:
|
||||
# If we were to pass None, we would receive limits for all projects
|
||||
# and we would have to return {project_id: [(name, limit), ...]}
|
||||
@@ -429,7 +501,7 @@ class _EnforcerUtils:
|
||||
raise ValueError('project_id must not be None')
|
||||
|
||||
project_limits = []
|
||||
proj_limits = self.connection.limits(
|
||||
proj_limits = self.connection.limits( # type: ignore
|
||||
service_id=self._service_id,
|
||||
region_id=self._region_id,
|
||||
project_id=project_id,
|
||||
@@ -441,7 +513,9 @@ class _EnforcerUtils:
|
||||
self.plimit_cache[project_id][name] = proj_limit
|
||||
return project_limits
|
||||
|
||||
def get_project_limits(self, project_id, resource_names):
|
||||
def get_project_limits(
|
||||
self, project_id: str, resource_names: list[str] | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Get all the limits for given project a resource_name list
|
||||
|
||||
If a limit is not found, it will be considered to be zero
|
||||
@@ -467,7 +541,7 @@ class _EnforcerUtils:
|
||||
|
||||
return project_limits
|
||||
|
||||
def _get_limit(self, project_id, resource_name):
|
||||
def _get_limit(self, project_id: str | None, resource_name: str) -> int:
|
||||
# If we are configured to cache limits, look in the cache first and use
|
||||
# the cached value if there is one. Else, retrieve the limit and add it
|
||||
# to the cache. Do this for both project limits and registered limits.
|
||||
@@ -480,13 +554,13 @@ class _EnforcerUtils:
|
||||
)
|
||||
|
||||
if project_limit:
|
||||
return project_limit.resource_limit
|
||||
return cast(int, project_limit.resource_limit)
|
||||
|
||||
# If there is no project limit, look for a registered limit.
|
||||
registered_limit = self._get_registered_limit(resource_name)
|
||||
|
||||
if registered_limit:
|
||||
return registered_limit.default_limit
|
||||
return cast(int, registered_limit.default_limit)
|
||||
|
||||
LOG.error(
|
||||
"Unable to find registered limit for resource "
|
||||
@@ -500,7 +574,9 @@ class _EnforcerUtils:
|
||||
)
|
||||
raise _LimitNotFound(resource_name)
|
||||
|
||||
def _get_project_limit(self, project_id, resource_name):
|
||||
def _get_project_limit(
|
||||
self, project_id: str, resource_name: str
|
||||
) -> _limit.Limit | None:
|
||||
# Look in the cache first.
|
||||
if (
|
||||
project_id in self.plimit_cache
|
||||
@@ -509,7 +585,7 @@ class _EnforcerUtils:
|
||||
return self.plimit_cache[project_id][resource_name]
|
||||
|
||||
# Get the limits from keystone.
|
||||
limits = self.connection.limits(
|
||||
limits = self.connection.limits( # type: ignore
|
||||
service_id=self._service_id,
|
||||
region_id=self._region_id,
|
||||
project_id=project_id,
|
||||
@@ -527,13 +603,15 @@ class _EnforcerUtils:
|
||||
|
||||
return limit
|
||||
|
||||
def _get_registered_limit(self, resource_name):
|
||||
def _get_registered_limit(
|
||||
self, resource_name: str
|
||||
) -> _registered_limit.RegisteredLimit | None:
|
||||
# Look in the cache first.
|
||||
if resource_name in self.rlimit_cache:
|
||||
return self.rlimit_cache[resource_name]
|
||||
|
||||
# Get the limits from keystone.
|
||||
reg_limits = self.connection.registered_limits(
|
||||
reg_limits = self.connection.registered_limits( # type: ignore
|
||||
service_id=self._service_id, region_id=self._region_id
|
||||
)
|
||||
reg_limit = None
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from keystoneauth1 import loading
|
||||
from oslo_config import cfg
|
||||
@@ -56,7 +57,7 @@ _options = [
|
||||
_option_group = 'oslo_limit'
|
||||
|
||||
|
||||
def list_opts():
|
||||
def list_opts() -> list[tuple[str, list[cfg.Opt]]]:
|
||||
"""Return a list of oslo.config options available in the library.
|
||||
|
||||
:returns: a list of (group_name, opts) tuples
|
||||
@@ -75,7 +76,7 @@ def list_opts():
|
||||
]
|
||||
|
||||
|
||||
def register_opts(conf):
|
||||
def register_opts(conf: cfg.ConfigOpts) -> None:
|
||||
loading.register_session_conf_options(CONF, _option_group)
|
||||
loading.register_adapter_conf_options(
|
||||
CONF, _option_group, include_deprecated=False
|
||||
@@ -84,6 +85,7 @@ def register_opts(conf):
|
||||
loading.register_auth_conf_options(CONF, _option_group)
|
||||
plugin_name = CONF.oslo_limit.auth_type
|
||||
if plugin_name:
|
||||
plugin_loader: loading.BaseLoader[Any]
|
||||
plugin_loader = loading.get_plugin_loader(plugin_name)
|
||||
plugin_opts = loading.get_auth_plugin_conf_options(plugin_loader)
|
||||
CONF.register_opts(plugin_opts, group=_option_group)
|
||||
|
||||
0
oslo_limit/py.typed
Normal file
0
oslo_limit/py.typed
Normal file
@@ -95,9 +95,9 @@ class TestFixture(base.BaseTestCase):
|
||||
|
||||
def test_calculate_usage(self):
|
||||
# Make sure the usage calculator works with the fixture too
|
||||
u = self.enforcer.calculate_usage('project2', ['widgets'])['widgets']
|
||||
self.assertEqual(3, u.usage)
|
||||
self.assertEqual(10, u.limit)
|
||||
u = self.enforcer.calculate_usage('project2', ['widgets'])
|
||||
self.assertEqual(3, u['widgets'].usage)
|
||||
self.assertEqual(10, u['widgets'].limit)
|
||||
|
||||
u = self.enforcer.calculate_usage('project1', ['widgets', 'sprockets'])
|
||||
self.assertEqual(10, u['sprockets'].usage)
|
||||
|
||||
@@ -16,6 +16,8 @@ test_limit
|
||||
Tests for `limit` module.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
import uuid
|
||||
|
||||
@@ -89,7 +91,7 @@ class TestEnforcer(base.BaseTestCase):
|
||||
|
||||
def test_get_model_impl(self):
|
||||
json = mock.MagicMock()
|
||||
limit._SDK_CONNECTION.get.return_value = json
|
||||
limit._SDK_CONNECTION.get.return_value = json # type: ignore
|
||||
|
||||
json.json.return_value = {"model": {"name": "flat"}}
|
||||
enforcer = limit.Enforcer(self._get_usage_for_project)
|
||||
@@ -195,7 +197,7 @@ class TestEnforcer(base.BaseTestCase):
|
||||
def test_get_registered_limits(self, mock_get_limits):
|
||||
mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)]
|
||||
|
||||
enforcer = limit.Enforcer(lambda: None)
|
||||
enforcer = limit.Enforcer(lambda: None) # type: ignore
|
||||
limits = enforcer.get_registered_limits(["a", "b", "c"])
|
||||
|
||||
mock_get_limits.assert_called_once_with(["a", "b", "c"])
|
||||
@@ -206,7 +208,7 @@ class TestEnforcer(base.BaseTestCase):
|
||||
project_id = uuid.uuid4().hex
|
||||
mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)]
|
||||
|
||||
enforcer = limit.Enforcer(lambda: None)
|
||||
enforcer = limit.Enforcer(lambda: None) # type: ignore
|
||||
limits = enforcer.get_project_limits(project_id, ["a", "b", "c"])
|
||||
|
||||
mock_get_limits.assert_called_once_with(project_id, ["a", "b", "c"])
|
||||
@@ -272,7 +274,7 @@ class TestFlatEnforcer(base.BaseTestCase):
|
||||
def test_get_registered_limits(self, mock_get_limits):
|
||||
mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)]
|
||||
|
||||
enforcer = limit._FlatEnforcer(lambda: None)
|
||||
enforcer = limit._FlatEnforcer(lambda: None) # type: ignore
|
||||
limits = enforcer.get_registered_limits(["a", "b", "c"])
|
||||
|
||||
mock_get_limits.assert_called_once_with(["a", "b", "c"])
|
||||
@@ -283,7 +285,7 @@ class TestFlatEnforcer(base.BaseTestCase):
|
||||
project_id = uuid.uuid4().hex
|
||||
mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)]
|
||||
|
||||
enforcer = limit._FlatEnforcer(lambda: None)
|
||||
enforcer = limit._FlatEnforcer(lambda: None) # type: ignore
|
||||
limits = enforcer.get_project_limits(project_id, ["a", "b", "c"])
|
||||
|
||||
mock_get_limits.assert_called_once_with(project_id, ["a", "b", "c"])
|
||||
@@ -627,7 +629,7 @@ class TestEnforcerUtils(base.BaseTestCase):
|
||||
self.mock_conn.get_endpoint.return_value = fake_endpoint
|
||||
|
||||
# a and c have limits, b doesn't have one
|
||||
empty_iterator = iter([])
|
||||
empty_iterator: Iterable[Any] = iter([])
|
||||
|
||||
a = registered_limit.RegisteredLimit()
|
||||
a.resource_name = "a"
|
||||
@@ -657,11 +659,13 @@ class TestEnforcerUtils(base.BaseTestCase):
|
||||
project_id = uuid.uuid4().hex
|
||||
|
||||
# a is a project limit, b, c and d don't have one
|
||||
empty_iterator = iter([])
|
||||
empty_iterator: Iterable[Any] = iter([])
|
||||
|
||||
a = klimit.Limit()
|
||||
a.resource_name = "a"
|
||||
a.resource_limit = 1
|
||||
a_iterator = iter([a])
|
||||
|
||||
self.mock_conn.limits.side_effect = [
|
||||
a_iterator,
|
||||
empty_iterator,
|
||||
@@ -674,6 +678,7 @@ class TestEnforcerUtils(base.BaseTestCase):
|
||||
b.resource_name = "b"
|
||||
b.default_limit = 2
|
||||
b_iterator = iter([b])
|
||||
|
||||
self.mock_conn.registered_limits.side_effect = [
|
||||
b_iterator,
|
||||
empty_iterator,
|
||||
@@ -697,6 +702,7 @@ class TestEnforcerUtils(base.BaseTestCase):
|
||||
utils = limit._EnforcerUtils(cache=cache)
|
||||
foo_limit = utils._get_project_limit(project_id, 'foo')
|
||||
|
||||
assert foo_limit is not None # narrow type
|
||||
self.assertEqual(3, foo_limit.resource_limit)
|
||||
self.assertEqual(1, fix.mock_conn.limits.call_count)
|
||||
|
||||
@@ -719,6 +725,7 @@ class TestEnforcerUtils(base.BaseTestCase):
|
||||
utils = limit._EnforcerUtils(cache=cache)
|
||||
foo_limit = utils._get_registered_limit('foo')
|
||||
|
||||
assert foo_limit is not None # narrow type
|
||||
self.assertEqual(5, foo_limit.default_limit)
|
||||
self.assertEqual(1, fix.mock_conn.registered_limits.call_count)
|
||||
|
||||
|
||||
@@ -40,6 +40,27 @@ packages = [
|
||||
"oslo_limit"
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
show_column_numbers = true
|
||||
show_error_context = true
|
||||
ignore_missing_imports = true
|
||||
strict = true
|
||||
# keep this in-sync with 'mypy.exclude' in '.pre-commit-config.yaml'
|
||||
exclude = '''
|
||||
(?x)(
|
||||
doc
|
||||
| releasenotes
|
||||
)
|
||||
'''
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["oslo_limit.tests.*"]
|
||||
disallow_untyped_calls = false
|
||||
disallow_untyped_defs = false
|
||||
disallow_subclassing_any = false
|
||||
disallow_any_generics = false
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
|
||||
|
||||
Reference in New Issue
Block a user