Add typing

Change-Id: I4b79a9bc3daaa1109c1bac9a135e8edfbef41ede
Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
Stephen Finucane
2026-01-19 09:13:41 +01:00
parent 6a54bd92db
commit 120e67e5a0
11 changed files with 272 additions and 119 deletions

View File

@@ -12,10 +12,14 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections.abc import Iterable
import enum
import os
import platform
import sys
from typing import Any, TYPE_CHECKING
import cffi
@@ -69,8 +73,54 @@ class Capabilities(enum.IntEnum):
CAP_CHECKPOINT_RESTORE = 40
CAPS_BYNAME = {}
CAPS_BYVALUE = {}
if TYPE_CHECKING:
# this is a bit long-winded, but it's the easiest way to expose these
# attributes via mypy
CAP_CHOWN = 0
CAP_DAC_OVERRIDE = 1
CAP_DAC_READ_SEARCH = 2
CAP_FOWNER = 3
CAP_FSETID = 4
CAP_KILL = 5
CAP_SETGID = 6
CAP_SETUID = 7
CAP_SETPCAP = 8
CAP_LINUX_IMMUTABLE = 9
CAP_NET_BIND_SERVICE = 10
CAP_NET_BROADCAST = 11
CAP_NET_ADMIN = 12
CAP_NET_RAW = 13
CAP_IPC_LOCK = 14
CAP_IPC_OWNER = 15
CAP_SYS_MODULE = 16
CAP_SYS_RAWIO = 17
CAP_SYS_CHROOT = 18
CAP_SYS_PTRACE = 19
CAP_SYS_PACCT = 20
CAP_SYS_ADMIN = 21
CAP_SYS_BOOT = 22
CAP_SYS_NICE = 23
CAP_SYS_RESOURCE = 24
CAP_SYS_TIME = 25
CAP_SYS_TTY_CONFIG = 26
CAP_MKNOD = 27
CAP_LEASE = 28
CAP_AUDIT_WRITE = 29
CAP_AUDIT_CONTROL = 30
CAP_SETFCAP = 31
CAP_MAC_OVERRIDE = 32
CAP_MAC_ADMIN = 33
CAP_SYSLOG = 34
CAP_WAKE_ALARM = 35
CAP_BLOCK_SUSPEND = 36
CAP_AUDIT_READ = 37
CAP_PERFMON = 38
CAP_BPF = 39
CAP_CHECKPOINT_RESTORE = 40
CAPS_BYNAME: dict[str, int] = {}
CAPS_BYVALUE: dict[int, str] = {}
module = sys.modules[__name__]
# Convenience dicts for human readable values
# module attributes for backwards compat/convenience
@@ -114,6 +164,11 @@ ffi = cffi.FFI()
ffi.cdef(CDEF)
crt: Any
_prctl: Any
_capget: Any
_capset: Any
if platform.system() == 'Linux':
# mock.patching crt.* directly seems to upset cffi. Use an
# indirection point here for easier testing.
@@ -127,7 +182,7 @@ else:
_capset = None
def set_keepcaps(enable):
def set_keepcaps(enable: bool) -> None:
"""Set/unset thread's "keep capabilities" flag - see prctl(2)"""
ret = _prctl(crt.PR_SET_KEEPCAPS, ffi.cast('unsigned long', bool(enable)))
if ret != 0:
@@ -135,7 +190,11 @@ def set_keepcaps(enable):
raise OSError(errno, os.strerror(errno))
def drop_all_caps_except(effective, permitted, inheritable):
def drop_all_caps_except(
effective: Iterable[int],
permitted: Iterable[int],
inheritable: Iterable[int],
) -> None:
"""Set (effective, permitted, inheritable) to provided list of caps"""
eff = _caps_to_mask(effective)
prm = _caps_to_mask(permitted)
@@ -159,12 +218,12 @@ def drop_all_caps_except(effective, permitted, inheritable):
raise OSError(errno, os.strerror(errno))
def _mask_to_caps(mask):
def _mask_to_caps(mask: int) -> list[int]:
"""Convert bitmask to list of set bit offsets"""
return [i for i in range(64) if (1 << i) & mask]
def _caps_to_mask(caps):
def _caps_to_mask(caps: Iterable[int]) -> int:
"""Convert list of bit offsets to bitmask"""
mask = 0
for cap in caps:
@@ -172,7 +231,7 @@ def _caps_to_mask(caps):
return mask
def get_caps():
def get_caps() -> tuple[list[int], list[int], list[int]]:
"""Return (effective, permitted, inheritable) as lists of caps"""
header = ffi.new(
'cap_user_header_t',

View File

@@ -20,12 +20,18 @@ python datatypes. Msgpack 'raw' is assumed to be a valid utf8 string
converted to tuples during serialization/deserialization.
"""
from __future__ import annotations
from collections.abc import Iterator
import datetime
import enum
import logging
import socket
import sys
import threading
from typing import Any
from typing_extensions import Self
import msgpack
@@ -52,16 +58,16 @@ class PrivsepTimeout(Exception):
class Serializer:
def __init__(self, writesock):
def __init__(self, writesock: socket.socket) -> None:
self.writesock = writesock
def send(self, msg):
def send(self, msg: Any) -> None:
buf = msgpack.packb(
msg, use_bin_type=True, unicode_errors='surrogateescape'
)
self.writesock.sendall(buf)
def close(self):
def close(self) -> None:
# Hilarious. `socket._socketobject.close()` doesn't actually
# call `self._sock.close()`. Oh well, we really wanted a half
# close anyway.
@@ -69,7 +75,7 @@ class Serializer:
class Deserializer:
def __init__(self, readsock):
def __init__(self, readsock: socket.socket) -> None:
self.readsock = readsock
self.unpacker = msgpack.Unpacker(
use_list=False,
@@ -81,10 +87,10 @@ class Deserializer:
max_buffer_size=100 * 1024 * 1024,
)
def __iter__(self):
def __iter__(self) -> Self:
return self
def __next__(self):
def __next__(self) -> Any:
while True:
try:
return next(self.unpacker)
@@ -101,23 +107,25 @@ class Deserializer:
class Future:
"""A very simple object to track the return of a function call"""
def __init__(self, lock, timeout=None):
def __init__(
self, lock: threading.Lock, timeout: float | None = None
) -> None:
self.condvar = threading.Condition(lock)
self.error = None
self.data = None
self.error: BaseException | None = None
self.data: Any = None
self.timeout = timeout
def set_result(self, data):
def set_result(self, data: Any) -> None:
"""Must already be holding lock used in constructor"""
self.data = data
self.condvar.notify()
def set_exception(self, exc):
def set_exception(self, exc: BaseException) -> None:
"""Must already be holding lock used in constructor"""
self.error = exc
self.condvar.notify()
def result(self):
def result(self) -> Any:
"""Must already be holding lock used in constructor"""
before = datetime.datetime.now()
if not self.condvar.wait(timeout=self.timeout):
@@ -139,7 +147,7 @@ class Future:
class ClientChannel:
def __init__(self, sock):
def __init__(self, sock: socket.socket) -> None:
self.running = False
self.writer = Serializer(sock)
self.lock = threading.Lock()
@@ -149,11 +157,11 @@ class ClientChannel:
args=(Deserializer(sock),),
)
self.reader_thread.daemon = True
self.outstanding_msgs = {}
self.outstanding_msgs: dict[str, Future] = {}
self.reader_thread.start()
def _reader_main(self, reader):
def _reader_main(self, reader: Deserializer) -> None:
"""This thread owns and demuxes the read channel"""
with self.lock:
self.running = True
@@ -183,11 +191,11 @@ class ClientChannel:
mbox.set_exception(exc)
self.running = False
def out_of_band(self, msg):
def out_of_band(self, msg: Any) -> None:
"""Received OOB message. Subclasses might want to override this."""
pass
def send_recv(self, msg, timeout=None):
def send_recv(self, msg: Any, timeout: float | None = None) -> Any:
myid = uuidutils.generate_uuid()
while myid in self.outstanding_msgs:
LOG.warning("myid shoudn't be in outstanding_msgs.")
@@ -208,7 +216,7 @@ class ClientChannel:
return reply
def close(self):
def close(self) -> None:
with self.lock:
self.writer.close()
@@ -218,19 +226,19 @@ class ClientChannel:
class ServerChannel:
"""Server-side twin to ClientChannel"""
def __init__(self, sock):
def __init__(self, sock: socket.socket) -> None:
self.rlock = threading.Lock()
self.reader_iter = iter(Deserializer(sock))
self.reader_iter: Iterator[Any] = iter(Deserializer(sock))
self.wlock = threading.Lock()
self.writer = Serializer(sock)
def __iter__(self):
def __iter__(self) -> Self:
return self
def __next__(self):
def __next__(self) -> Any:
with self.rlock:
return next(self.reader_iter)
def send(self, msg):
def send(self, msg: Any) -> None:
with self.wlock:
self.writer.send(msg)

View File

@@ -43,6 +43,10 @@ The privsep daemon exits when the communication channel is closed,
"""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Iterable
from concurrent import futures
import enum
import errno
@@ -55,6 +59,8 @@ import sys
import tempfile
import threading
import traceback
from typing import Any
from typing import TYPE_CHECKING
import debtcollector
import eventlet
@@ -68,6 +74,9 @@ from oslo_privsep._i18n import _
from oslo_privsep import capabilities
from oslo_privsep import comm
if TYPE_CHECKING:
from oslo_privsep import priv_context
if platform.system() == 'Linux':
import fcntl
import grp
@@ -76,7 +85,7 @@ if platform.system() == 'Linux':
LOG = logging.getLogger(__name__)
EVENTLET_MODULES = (
EVENTLET_MODULES: tuple[str, ...] = (
'os',
'select',
'socket',
@@ -86,10 +95,10 @@ EVENTLET_MODULES = (
'builtins',
'subprocess',
)
EVENTLET_LIBRARIES = []
EVENTLET_LIBRARIES: list[tuple[str, Any]] = []
def _null():
def _null() -> list[Any]:
return []
@@ -133,18 +142,18 @@ class ProtocolError(Exception):
pass
def set_cloexec(fd):
def set_cloexec(fd: int | socket.socket) -> None:
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
if (flags & fcntl.FD_CLOEXEC) == 0:
flags |= fcntl.FD_CLOEXEC
fcntl.fcntl(fd, fcntl.F_SETFD, flags)
def setuid(user_id_or_name):
def setuid(user_id_or_name: str | int) -> None:
try:
new_uid = int(user_id_or_name)
except (TypeError, ValueError):
new_uid = pwd.getpwnam(user_id_or_name).pw_uid
new_uid = pwd.getpwnam(str(user_id_or_name)).pw_uid
if new_uid != 0:
try:
os.setuid(new_uid)
@@ -154,11 +163,11 @@ def setuid(user_id_or_name):
raise FailedToDropPrivileges(msg)
def setgid(group_id_or_name):
def setgid(group_id_or_name: str | int) -> None:
try:
new_gid = int(group_id_or_name)
except (TypeError, ValueError):
new_gid = grp.getgrnam(group_id_or_name).gr_gid
new_gid = grp.getgrnam(str(group_id_or_name)).gr_gid
if new_gid != 0:
try:
os.setgid(new_gid)
@@ -169,12 +178,16 @@ def setgid(group_id_or_name):
class PrivsepLogHandler(pylogging.Handler):
def __init__(self, channel, processName=None):
def __init__(
self,
channel: comm.ServerChannel,
processName: str | None = None,
) -> None:
super().__init__()
self.channel = channel
self.processName = processName
def emit(self, record):
def emit(self, record: pylogging.LogRecord) -> None:
# Vaguely based on pylogging.handlers.SocketHandler.makePickle
if self.processName:
@@ -198,13 +211,15 @@ class PrivsepLogHandler(pylogging.Handler):
class _ClientChannel(comm.ClientChannel):
"""Our protocol, layered on the basic primitives in comm.ClientChannel"""
def __init__(self, sock, context):
def __init__(
self, sock: socket.socket, context: priv_context.PrivContext
) -> None:
self.log = logging.getLogger(context.conf.logger_name)
self.log_traceback = context.conf.log_daemon_traceback
self.log_traceback: bool = context.conf.log_daemon_traceback
super().__init__(sock)
self.exchange_ping()
def exchange_ping(self):
def exchange_ping(self) -> None:
try:
# exchange "ready" messages
reply = self.send_recv((comm.Message.PING.value,))
@@ -219,7 +234,13 @@ class _ClientChannel(comm.ClientChannel):
self.log.critical(msg)
raise FailedToDropPrivileges(msg)
def remote_call(self, name, args, kwargs, timeout):
def remote_call(
self,
name: str,
args: tuple[Any, ...],
kwargs: dict[str, Any],
timeout: float | None,
) -> Any:
result = self.send_recv(
(comm.Message.CALL.value, name, args, kwargs), timeout
)
@@ -242,7 +263,7 @@ class _ClientChannel(comm.ClientChannel):
else:
raise ProtocolError(_('Unexpected response: %r') % result)
def out_of_band(self, msg):
def out_of_band(self, msg: Any) -> None:
if msg[0] == comm.Message.LOG:
# (LOG, LogRecord __dict__)
message = {
@@ -258,7 +279,7 @@ class _ClientChannel(comm.ClientChannel):
)
def fdopen(fd, *args, **kwargs):
def fdopen(fd: int, *args: Any, **kwargs: Any) -> Any:
# NOTE(gus): We can't just use os.fdopen() here and allow the
# regular (optional) monkey_patching to do its thing. Turns out
# that regular file objects (as returned by os.fdopen) on python2
@@ -271,13 +292,13 @@ def fdopen(fd, *args, **kwargs):
return open(fd, *args, **kwargs)
def _fd_logger(level=logging.WARN):
def _fd_logger(level: int = logging.WARN) -> Any:
"""Helper that returns a file object that is asynchronously logged"""
read_fd, write_fd = os.pipe()
read_end = fdopen(read_fd, 'r', 1)
write_end = fdopen(write_fd, 'w', 1)
def logger(f):
def logger(f: Any) -> None:
for line in f:
LOG.log(level, 'privsep log: %s', line.rstrip())
@@ -288,7 +309,10 @@ def _fd_logger(level=logging.WARN):
return write_end
def replace_logging(handler, log_root=None):
def replace_logging(
handler: pylogging.Handler,
log_root: pylogging.Logger | None = None,
) -> None:
if log_root is None:
log_root = logging.getLogger(None).logger # root logger
for h in log_root.handlers:
@@ -296,7 +320,7 @@ def replace_logging(handler, log_root=None):
log_root.addHandler(handler)
def un_monkey_patch():
def un_monkey_patch() -> None:
for eventlet_mod_name, func_modules in EVENTLET_LIBRARIES:
if not eventlet.patcher.is_monkey_patched(eventlet_mod_name):
continue
@@ -312,7 +336,7 @@ def un_monkey_patch():
class ForkingClientChannel(_ClientChannel):
def __init__(self, context):
def __init__(self, context: priv_context.PrivContext) -> None:
"""Start privsep daemon using fork()
Assumes we already have required privileges.
@@ -353,7 +377,7 @@ class ForkingClientChannel(_ClientChannel):
class RootwrapClientChannel(_ClientChannel):
def __init__(self, context):
def __init__(self, context: priv_context.PrivContext) -> None:
"""Start privsep daemon using exec()
Uses sudo/rootwrap to gain privileges.
@@ -409,18 +433,22 @@ class RootwrapClientChannel(_ClientChannel):
class Daemon:
"""NB: This doesn't fork() - do that yourself before calling run()"""
def __init__(self, channel, context):
def __init__(
self,
channel: comm.ServerChannel,
context: priv_context.PrivContext,
) -> None:
self.channel = channel
self.context = context
self.user = context.conf.user
self.group = context.conf.group
self.caps = set(context.conf.capabilities)
self.user: str | int | None = context.conf.user
self.group: str | int | None = context.conf.group
self.caps: set[int] = set(context.conf.capabilities)
self.thread_pool = futures.ThreadPoolExecutor(
context.conf.thread_pool_size
)
self.communication_error = None
self.communication_error: BaseException | None = None
def run(self):
def run(self) -> None:
"""Run request loop. Sets up environment, then calls loop()"""
os.chdir("/")
os.umask(0)
@@ -429,13 +457,13 @@ class Daemon:
self.loop()
def _close_stdio(self):
def _close_stdio(self) -> None:
with open(os.devnull, 'w+') as devnull:
os.dup2(devnull.fileno(), StdioFd.STDIN)
os.dup2(devnull.fileno(), StdioFd.STDOUT)
# stderr is left untouched
def _drop_privs(self):
def _drop_privs(self) -> None:
try:
# Keep current capabilities across setuid away from root.
capabilities.set_keepcaps(True)
@@ -462,7 +490,7 @@ class Daemon:
capabilities.drop_all_caps_except(self.caps, self.caps, [])
def fmt_caps(capset):
def fmt_caps(capset: Iterable[int]) -> str:
if not capset:
return 'none'
fc = [capabilities.CAPS_BYVALUE.get(c, str(c)) for c in capset]
@@ -480,7 +508,9 @@ class Daemon:
},
)
def _process_cmd(self, msgid, cmd, *args):
def _process_cmd(
self, msgid: str, cmd: comm.Message, *args: Any
) -> tuple[Any, ...]:
"""Executes the requested command in an execution thread.
This executes a call within a thread executor and returns the results
@@ -523,7 +553,9 @@ class Daemon:
traceback.format_exc(),
)
def _create_done_callback(self, msgid):
def _create_done_callback(
self, msgid: str
) -> Callable[[futures.Future[tuple[Any, ...]]], None]:
"""Creates a future callback to receive command execution results.
:param msgid: The message identifier.
@@ -531,7 +563,7 @@ class Daemon:
"""
channel = self.channel
def _call_back(result):
def _call_back(result: futures.Future[tuple[Any, ...]]) -> None:
"""Future execution callback.
:param result: The `future` execution and its results.
@@ -544,7 +576,7 @@ class Daemon:
)
channel.send((msgid, reply))
except OSError:
self.communication_error = sys.exc_info()
self.communication_error = sys.exc_info()[1]
except Exception as e:
LOG.debug(
'privsep: Exception during request[%(msgid)s]: %(err)s',
@@ -566,7 +598,7 @@ class Daemon:
return _call_back
def loop(self):
def loop(self) -> None:
"""Main body of daemon request loop"""
LOG.info('privsep daemon running as pid %s', os.getpid())
@@ -577,7 +609,7 @@ class Daemon:
for msgid, msg in self.channel:
error = self.communication_error
if error:
if error.errno == errno.EPIPE:
if getattr(error, 'errno', None) == errno.EPIPE:
# Write stream closed, exit loop
break
raise error
@@ -589,7 +621,7 @@ class Daemon:
LOG.debug('Socket closed, shutting down privsep daemon')
def helper_main():
def helper_main() -> None:
"""Start privileged process, serving requests over a Unix socket."""
cfg.CONF.register_cli_opts(
@@ -606,9 +638,9 @@ def helper_main():
logging.setup(cfg.CONF, 'privsep', fix_eventlet=False)
context = importutils.import_class(cfg.CONF.privsep_context)
from oslo_privsep import priv_context # Avoid circular import
from oslo_privsep import priv_context as priv_context_mod # Avoid circular
if not isinstance(context, priv_context.PrivContext):
if not isinstance(context, priv_context_mod.PrivContext):
LOG.fatal(
'--privsep_context must be the (python) name of a '
'PrivContext object'

View File

@@ -12,6 +12,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Iterable
import copy
import enum
import functools
@@ -19,6 +23,7 @@ import logging
import multiprocessing
import shlex
import threading
from typing import Any
from oslo_config import cfg
from oslo_config import types
@@ -32,12 +37,12 @@ from oslo_privsep import daemon
LOG = logging.getLogger(__name__)
def CapNameOrInt(value):
value = str(value).strip()
def CapNameOrInt(value: str | int) -> int:
value_str = str(value).strip()
try:
return capabilities.CAPS_BYNAME[value]
return capabilities.CAPS_BYNAME[value_str]
except KeyError:
return int(value)
return int(value_str)
OPTS = [
@@ -96,7 +101,7 @@ _ENTRYPOINT_ATTR = 'privsep_entrypoint'
_HELPER_COMMAND_PREFIX = ['sudo']
def _list_opts():
def _list_opts() -> list[tuple[cfg.OptGroup, list[cfg.Opt]]]:
"""Returns a list of oslo.config options available in the library.
The returned list includes all oslo.config options which may be registered
@@ -130,7 +135,7 @@ class Method(enum.Enum):
ROOTWRAP = 2
def init(root_helper=None):
def init(root_helper: list[str] | None = None) -> None:
"""Initialise oslo.privsep library.
This function should be called at the top of main(), after the
@@ -151,13 +156,13 @@ def init(root_helper=None):
class PrivContext:
def __init__(
self,
prefix,
cfg_section='privsep',
pypath=None,
capabilities=None,
logger_name='oslo_privsep.daemon',
timeout=None,
):
prefix: str,
cfg_section: str = 'privsep',
pypath: str | None = None,
capabilities: Iterable[int] | None = None,
logger_name: str = 'oslo_privsep.daemon',
timeout: float | None = None,
) -> None:
# Note that capabilities=[] means retaining no capabilities
# and leaves even uid=0 with no powers except being able to
# read/write to the filesystem as uid=0. This might be what
@@ -173,12 +178,12 @@ class PrivContext:
self.cfg_section = cfg_section
self.client_mode = True
self.channel = None
self.channel: daemon._ClientChannel | None = None
self.start_lock = threading.Lock()
cfg.CONF.register_opts(OPTS, group=cfg_section)
cfg.CONF.set_default(
'capabilities', group=cfg_section, default=capabilities
'capabilities', group=cfg_section, default=list(capabilities)
)
cfg.CONF.set_default(
'logger_name', group=cfg_section, default=logger_name
@@ -186,16 +191,16 @@ class PrivContext:
self.timeout = timeout
@property
def conf(self):
def conf(self) -> Any:
"""Return the oslo.config section object as lazily as possible."""
# Need to avoid looking this up before oslo_config has been
# properly initialized.
return cfg.CONF[self.cfg_section]
def __repr__(self):
def __repr__(self) -> str:
return f'PrivContext(cfg_section={self.cfg_section})'
def helper_command(self, sockpath):
def helper_command(self, sockpath: str) -> list[str]:
# We need to be able to reconstruct the context object in the new
# python process we'll get after rootwrap/sudo. This means we
# need to construct the context object and store it somewhere
@@ -243,19 +248,21 @@ class PrivContext:
return cmd
def set_client_mode(self, enabled):
def set_client_mode(self, enabled: bool) -> None:
self.client_mode = enabled
def entrypoint(self, func):
def entrypoint(self, func: Callable[..., Any]) -> functools.partial[Any]:
"""This is intended to be used as a decorator."""
return self._entrypoint(func)
def entrypoint_with_timeout(self, timeout):
def entrypoint_with_timeout(
self, timeout: float
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""This is intended to be used as a decorator with timeout."""
def wrap(func):
def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(func)
def inner(*args, **kwargs):
def inner(*args: Any, **kwargs: Any) -> Any:
f = self._entrypoint(func)
return f(*args, _wrap_timeout=timeout, **kwargs)
@@ -264,7 +271,7 @@ class PrivContext:
return wrap
def _entrypoint(self, func):
def _entrypoint(self, func: Callable[..., Any]) -> functools.partial[Any]:
if not func.__module__.startswith(self.prefix):
raise AssertionError(
f'{self!r} entrypoints must be below "{self.prefix}"'
@@ -284,10 +291,16 @@ class PrivContext:
setattr(f, _ENTRYPOINT_ATTR, self)
return f
def is_entrypoint(self, func):
def is_entrypoint(self, func: Callable[..., Any]) -> bool:
return getattr(func, _ENTRYPOINT_ATTR, None) is self
def _wrap(self, func, *args, _wrap_timeout=None, **kwargs):
def _wrap(
self,
func: Callable[..., Any],
*args: Any,
_wrap_timeout: float | None = None,
**kwargs: Any,
) -> Any:
if self.client_mode:
name = f'{func.__module__}.{func.__name__}'
if self.channel is not None and not self.channel.running:
@@ -295,17 +308,21 @@ class PrivContext:
self.stop()
if self.channel is None:
self.start()
if self.channel is None:
# narrow type: this will always be non-None thank to the above
raise RuntimeError('channel is not initialized')
r_call_timeout = _wrap_timeout or self.timeout
return self.channel.remote_call(name, args, kwargs, r_call_timeout)
else:
return func(*args, **kwargs)
def start(self, method=Method.ROOTWRAP):
def start(self, method: Method = Method.ROOTWRAP) -> None:
with self.start_lock:
if self.channel is not None:
LOG.warning('privsep daemon already running')
return
channel: daemon._ClientChannel
if method is Method.ROOTWRAP:
channel = daemon.RootwrapClientChannel(context=self)
elif method is Method.FORK:
@@ -315,7 +332,7 @@ class PrivContext:
self.channel = channel
def stop(self):
def stop(self) -> None:
if self.channel is not None:
self.channel.close()
self.channel = None

0
oslo_privsep/py.typed Normal file
View File

View File

@@ -15,6 +15,7 @@
import logging
import os
import sys
from typing import Any
import fixtures
from oslo_config import fixture as cfg_fixture
@@ -25,11 +26,15 @@ LOG = logging.getLogger(__name__)
class UnprivilegedPrivsepFixture(fixtures.Fixture):
def __init__(self, context, config_override):
def __init__(
self,
context: priv_context.PrivContext,
config_override: dict[str, Any],
) -> None:
self.context = context
self.config_override = config_override
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.conf = self.useFixture(cfg_fixture.Config()).conf

View File

@@ -46,8 +46,8 @@ class TestSerialization(base.BaseTestCase):
sock = BufSock()
self.input = comm.Serializer(sock)
self.output = iter(comm.Deserializer(sock))
self.input = comm.Serializer(sock) # type: ignore[arg-type]
self.output = iter(comm.Deserializer(sock)) # type: ignore[arg-type]
def send(self, data):
self.input.send(data)

View File

@@ -55,7 +55,7 @@ def get_fake_context(conf_attrs=None, **context_attrs):
capabilities.CAP_NET_ADMIN,
]
context.conf.logger_name = 'oslo_privsep.daemon'
vars(context).update(context_attrs)
vars(context).update(context_attrs) # type: ignore[attr-defined]
vars(context.conf).update(conf_attrs)
return context
@@ -109,15 +109,16 @@ class LogTest(testctx.TestContextTestCase):
self.assertIn('test@WARN', logger.output)
def test_record_data(self):
logs = []
logs: list[pylogging.LogRecord] = []
# fixtures.FakeLogger accepts only a formatter class/function, not an
# instance :(
fmt = functools.partial(LogRecorder, logs)
self.useFixture(
fixtures.FakeLogger(
level=logging.INFO,
format='dummy',
# fixtures.FakeLogger accepts only a formatter
# class/function, not an instance :(
formatter=functools.partial(LogRecorder, logs),
formatter=fmt, # type: ignore[arg-type]
)
)
@@ -142,15 +143,16 @@ class LogTest(testctx.TestContextTestCase):
self.assertEqual('logme', record.funcName)
def test_format_record(self):
logs = []
logs: list[pylogging.LogRecord] = []
# fixtures.FakeLogger accepts only a formatter class/function, not an
# instance :(
fmt = functools.partial(LogRecorder, logs)
self.useFixture(
fixtures.FakeLogger(
level=logging.INFO,
format='dummy',
# fixtures.FakeLogger accepts only a formatter
# class/function, not an instance :(
formatter=functools.partial(LogRecorder, logs),
formatter=fmt, # type: ignore[arg-type]
)
)
@@ -181,14 +183,16 @@ class LogTestDaemonTraceback(testctx.TestContextTestCase):
self.privsep_conf.set_override(
'log_daemon_traceback', True, group='privsep'
)
logs = []
logs: list[pylogging.LogRecord] = []
# fixtures.FakeLogger accepts only a formatter class/function, not an
# instance :(
fmt = functools.partial(LogRecorder, logs)
self.useFixture(
fixtures.FakeLogger(
level=logging.INFO,
format='dummy',
# fixtures.FakeLogger accepts only a formatter
# class/function, not an instance :(
formatter=functools.partial(LogRecorder, logs),
formatter=fmt, # type: ignore[arg-type]
)
)
@@ -303,7 +307,7 @@ class ClientChannelTestCase(base.BaseTestCase):
self.client_channel.out_of_band([comm.Message.PING])
mock_warning.assert_called_once()
@mock.patch.object(daemon.logging, 'getLogger')
@mock.patch.object(daemon.logging, 'getLogger') # type: ignore[attr-defined]
@mock.patch.object(pylogging, 'makeLogRecord')
def test_out_of_band_log_message_context_logger(
self, make_log_mock, get_logger_mock

View File

@@ -142,7 +142,7 @@ class PrivContextTest(testctx.TestContextTestCase):
def test_start_acquires_lock(self):
context = priv_context.PrivContext('test', capabilities=[])
context.channel = "something not None"
context.channel = "something not None" # type: ignore[assignment]
context.start_lock = mock.Mock()
context.start_lock.__enter__ = mock.Mock()
context.start_lock.__exit__ = mock.Mock()

View File

@@ -24,6 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: Implementation :: CPython",
"Typing :: Typed",
]
dynamic = ["version", "dependencies"]
@@ -49,3 +50,17 @@ docstring-code-format = true
[tool.ruff.lint]
select = ["E4", "E5", "E7", "E9", "F", "G", "LOG", "S", "UP"]
[tool.mypy]
python_version = "3.10"
show_column_numbers = true
show_error_context = true
strict = true
disable_error_code = ["import-untyped"]
exclude = '(?x)(doc | releasenotes)'
[[tool.mypy.overrides]]
module = ["oslo_privsep.tests.*", "oslo_privsep.functional.*"]
disallow_untyped_calls = false
disallow_untyped_defs = false
disallow_subclassing_any = false

17
tox.ini
View File

@@ -5,8 +5,8 @@ envlist = py3,pypy,pep8
[testenv]
deps =
-c{env:TOX_CONSTRAINTS_FILE:https://releases.openstack.org/constraints/upper/master}
-r{toxinidir}/test-requirements.txt
-r{toxinidir}/requirements.txt
-r{toxinidir}/test-requirements.txt
commands = stestr run --slowest {posargs}
[testenv:functional]
@@ -51,11 +51,23 @@ commands =
sphinx-build -a -E -W -d releasenotes/build/doctrees --keep-going -b html releasenotes/source releasenotes/build/html
[testenv:pep8]
skip_install = true
description =
Run style checks.
deps =
pre-commit>=2.6.0 # MIT
{[testenv:mypy]deps}
commands =
pre-commit run -a
{[testenv:mypy]commands}
[testenv:mypy]
description =
Run type checks.
deps =
{[testenv]deps}
mypy
commands =
mypy --cache-dir="{envdir}/mypy_cache" {posargs:oslo_privsep}
[flake8]
# We only enable the hacking (H) checks
@@ -67,3 +79,4 @@ exclude = .venv,.git,.tox,dist,doc,*lib/python*,*egg,build
[hacking]
import_exceptions =
oslo_privsep._i18n
typing