Merge "Add typing"

This commit is contained in:
Zuul
2025-11-20 11:42:47 +00:00
committed by Gerrit Code Review
8 changed files with 136 additions and 56 deletions

View File

@@ -24,3 +24,15 @@ repos:
- id: hacking
additional_dependencies: []
exclude: '^(doc|releasenotes|tools)/.*$'
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
hooks:
- id: mypy
additional_dependencies:
- types-WebOb
# keep this in-sync with '[mypy] exclude' in 'pyproject.toml'
exclude: |
(?x)(
doc/.*
| releasenotes/.*
)

View File

@@ -12,7 +12,11 @@
# limitations under the License.
import collections
from collections.abc import Iterable, MutableMapping, Sequence
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from typing_extensions import Self
ENVIRON_HTTP_HEADER_FMT = 'http_{}'
STANDARD_HEADER = 'openstack-api-version'
@@ -24,17 +28,24 @@ class Version(collections.namedtuple('Version', 'major minor')):
Since it is a tuple, it is automatically comparable.
"""
def __new__(cls, major, minor):
max_version: tuple[int, int]
min_version: tuple[int, int]
def __new__(cls, major: int, minor: int) -> 'Self':
"""Add min and max version attributes to the tuple."""
self = super().__new__(cls, major, minor)
self.max_version = (-1, 0)
self.min_version = (-1, 0)
return self
def __str__(self):
def __str__(self) -> str:
return f'{self.major}.{self.minor}'
def matches(self, min_version=None, max_version=None):
def matches(
self,
min_version: tuple[int, int] | None = None,
max_version: tuple[int, int] | None = None,
) -> bool:
"""Is this version within min_version and max_version."""
# NOTE(cdent): min_version and max_version are expected
# to be set by the code that is creating the Version, if
@@ -46,31 +57,34 @@ class Version(collections.namedtuple('Version', 'major minor')):
return min_version <= self <= max_version
def get_version(headers, service_type, legacy_headers=None):
def get_version(
headers: Iterable[tuple[str, str]] | MutableMapping[str, str],
service_type: str,
legacy_headers: Iterable[str] | None = None,
) -> str | None:
"""Parse a microversion out of headers
If headers is not a dict we assume is an iterator of tuple-like headers,
which we will fold into a dict.
The flow is that we first look for the new standard singular header:
* ``openstack-api-version: <service> <version>``
If that's not present we fall back to the headers listed in
``legacy_headers``. These often look like this:
* ``openstack-<service>-api-version: <version>``
* ``openstack-<legacy>-api-version: <version>``
* ``x-openstack-<legacy>-api-version: <version>``
Folded headers are joined by ``,``.
:param headers: The headers of a request, dict or list
:param service_type: The service type being looked for in the headers
:param legacy_headers: Other headers to look at for a version
:returns: a version string or "latest"
:raises: ValueError
If headers is not a dict we assume is an iterator of
tuple-like headers, which we will fold into a dict.
The flow is that we first look for the new standard singular
header:
* openstack-api-version: <service> <version>
If that's not present we fall back to the headers listed in
legacy_headers. These often look like this:
* openstack-<service>-api-version: <version>
* openstack-<legacy>-api-version: <version>
* x-openstack-<legacy>-api-version: <version>
Folded headers are joined by ','.
"""
folded_headers = fold_headers(headers)
@@ -86,7 +100,9 @@ def get_version(headers, service_type, legacy_headers=None):
return None
def check_legacy_headers(headers, legacy_headers):
def check_legacy_headers(
headers: MutableMapping[str, str], legacy_headers: Iterable[str]
) -> str | None:
"""Gather values from old headers."""
for legacy_header in legacy_headers:
try:
@@ -97,7 +113,9 @@ def check_legacy_headers(headers, legacy_headers):
return None
def check_standard_header(headers, service_type):
def check_standard_header(
headers: MutableMapping[str, str], service_type: str
) -> str | None:
"""Parse the standard header to get value for service."""
try:
header = _extract_header_value(headers, STANDARD_HEADER)
@@ -111,8 +129,12 @@ def check_standard_header(headers, service_type):
except (KeyError, ValueError):
return None
return None
def fold_headers(headers):
# we accept Any even though we know this will be a list of 2-item tuples or a
# dict, in order to avoid reworking logic
def fold_headers(headers: Any) -> MutableMapping[str, str]:
"""Turn a list of headers into a folded dict."""
# If it behaves like a dict, return it. Webob uses objects which
# are not dicts, but behave like them.
@@ -120,6 +142,7 @@ def fold_headers(headers):
return dict((k.lower(), v) for k, v in headers.items())
except AttributeError:
pass
header_dict = collections.defaultdict(list)
for header, value in headers:
header_dict[header.lower()].append(value.strip())
@@ -131,24 +154,28 @@ def fold_headers(headers):
return folded_headers
def headers_from_wsgi_environ(environ):
def headers_from_wsgi_environ(
environ: MutableMapping[str, str],
) -> dict[str, str]:
"""Extract all the HTTP_ keys and values from environ to a new dict.
Note that this does not change the keys in any way in the returned
dict. Nor is the incoming environ modified.
Note that this does not change the keys in any way in the returned dict.
Nor is the incoming environ modified.
:param environ: A PEP 3333 compliant WSGI environ dict.
"""
return {key: environ[key] for key in environ if key.startswith('HTTP_')}
def _extract_header_value(headers, header_name):
def _extract_header_value(
headers: MutableMapping[str, str], header_name: str
) -> str:
"""Get the value of a header.
The provided headers is a dict. If a key doesn't exist for
header_name, try using the WSGI environ form of the name.
The provided headers is a dict. If a key doesn't exist for ``header_name``,
try using the WSGI environ form of the name.
Raises KeyError if neither key is found.
:raises: KeyError if neither key is found.
"""
try:
value = headers[header_name]
@@ -160,7 +187,7 @@ def _extract_header_value(headers, header_name):
return value
def parse_version_string(version_string):
def parse_version_string(version_string: str) -> Version:
"""Turn a version string into a Version
:param version_string: A string of two numerals, X.Y.
@@ -178,21 +205,24 @@ def parse_version_string(version_string):
raise TypeError(f'invalid version string: {version_string}; {exc}')
def extract_version(headers, service_type, versions_list):
def extract_version(
headers: Iterable[tuple[str, str]] | MutableMapping[str, str],
service_type: str,
versions_list: Sequence[str],
) -> Version:
"""Extract the microversion from the headers.
There may be multiple headers and some which don't match our
service.
There may be multiple headers and some which don't match our service.
If no version is found then the extracted version is the minimum
available version.
If no version is found then the extracted version is the minimum available
version.
:param headers: Request headers as dict list or WSGI environ
:param service_type: The service_type as a string
:param service_type: The service type as a string
:param versions_list: List of all possible microversions as strings,
sorted from earliest to latest version.
:returns: a Version with the optional min_version and max_version
attributes set.
sorted from earliest to latest version.
:returns: a :class:`~Version` with the optional ``min_version`` and
``max_version`` attributes set.
:raises: ValueError
"""
found_version = get_version(headers, service_type=service_type)

View File

@@ -10,13 +10,27 @@
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""WSGI middleware for getting microversion info."""
from collections.abc import Sequence
from typing import Any, Protocol, TYPE_CHECKING
import webob
import webob.dec
import webob.exc
import microversion_parse
if TYPE_CHECKING:
from _typeshed.wsgi import WSGIApplication
class _JSONFormatter(Protocol):
def __call__(
self, *, body: str, status: str, title: str, environ: dict[str, Any]
) -> Any: ...
class MicroversionMiddleware:
"""WSGI middleware for getting microversion info.
@@ -38,8 +52,12 @@ class MicroversionMiddleware:
"""
def __init__(
self, application, service_type, versions, json_error_formatter=None
):
self,
application: 'WSGIApplication | None',
service_type: str,
versions: Sequence[str],
json_error_formatter: _JSONFormatter | None = None,
) -> None:
"""Create the WSGI middleware.
:param application: The application hosting the service.
@@ -57,7 +75,10 @@ class MicroversionMiddleware:
self.json_error_formatter = json_error_formatter
@webob.dec.wsgify
def __call__(self, req):
def __call__(
self,
req: webob.request.Request,
) -> webob.response.Response | None:
try:
microversion = microversion_parse.extract_version(
req.headers, self.service_type, self.versions

View File

View File

@@ -203,10 +203,7 @@ class TestGetHeaders(testtools.TestCase):
self.assertEqual('11.12', version)
def test_no_headers(self):
headers = {}
version = microversion_parse.get_version(
headers, service_type='compute'
)
version = microversion_parse.get_version({}, service_type='compute')
self.assertEqual(None, version)
def test_unfolded_service(self):

View File

@@ -18,17 +18,13 @@ import microversion_parse
class TestHeadersFromWSGIEnviron(testtools.TestCase):
def test_empty_environ(self):
environ = {}
expected = {}
self.assertEqual(
expected, microversion_parse.headers_from_wsgi_environ(environ)
)
self.assertEqual({}, microversion_parse.headers_from_wsgi_environ({}))
def test_non_empty_no_headers(self):
environ = {'PATH_INFO': '/foo/bar'}
expected = {}
found_headers = microversion_parse.headers_from_wsgi_environ(environ)
self.assertEqual(expected, found_headers)
self.assertEqual(
{}, microversion_parse.headers_from_wsgi_environ(environ)
)
def test_headers(self):
environ = {

View File

@@ -34,6 +34,25 @@ packages = [
"microversion_parse"
]
[tool.mypy]
python_version = "3.10"
show_column_numbers = true
show_error_context = true
strict = true
# keep this in-sync with 'mypy.exclude' in '.pre-commit-config.yaml'
exclude = '''
(?x)(
doc
| releasenotes
)
'''
[[tool.mypy.overrides]]
module = ["microversion_parse.tests.*"]
disallow_untyped_calls = false
disallow_untyped_defs = false
disallow_subclassing_any = false
[tool.ruff]
line-length = 79

View File

@@ -53,3 +53,8 @@ select = H
ignore = H405
exclude = .venv,.git,.tox,dist,*egg,*.egg-info,build,examples,doc
show-source = true
[hacking]
import_exceptions =
collections.abc
typing