Add typing
Change-Id: I4b79a9bc3daaa1109c1bac9a135e8edfbef41ede Signed-off-by: Stephen Finucane <stephenfin@redhat.com>
This commit is contained in:
@@ -12,10 +12,14 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
import enum
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import cffi
|
||||
|
||||
@@ -69,8 +73,54 @@ class Capabilities(enum.IntEnum):
|
||||
CAP_CHECKPOINT_RESTORE = 40
|
||||
|
||||
|
||||
CAPS_BYNAME = {}
|
||||
CAPS_BYVALUE = {}
|
||||
if TYPE_CHECKING:
|
||||
# this is a bit long-winded, but it's the easiest way to expose these
|
||||
# attributes via mypy
|
||||
CAP_CHOWN = 0
|
||||
CAP_DAC_OVERRIDE = 1
|
||||
CAP_DAC_READ_SEARCH = 2
|
||||
CAP_FOWNER = 3
|
||||
CAP_FSETID = 4
|
||||
CAP_KILL = 5
|
||||
CAP_SETGID = 6
|
||||
CAP_SETUID = 7
|
||||
CAP_SETPCAP = 8
|
||||
CAP_LINUX_IMMUTABLE = 9
|
||||
CAP_NET_BIND_SERVICE = 10
|
||||
CAP_NET_BROADCAST = 11
|
||||
CAP_NET_ADMIN = 12
|
||||
CAP_NET_RAW = 13
|
||||
CAP_IPC_LOCK = 14
|
||||
CAP_IPC_OWNER = 15
|
||||
CAP_SYS_MODULE = 16
|
||||
CAP_SYS_RAWIO = 17
|
||||
CAP_SYS_CHROOT = 18
|
||||
CAP_SYS_PTRACE = 19
|
||||
CAP_SYS_PACCT = 20
|
||||
CAP_SYS_ADMIN = 21
|
||||
CAP_SYS_BOOT = 22
|
||||
CAP_SYS_NICE = 23
|
||||
CAP_SYS_RESOURCE = 24
|
||||
CAP_SYS_TIME = 25
|
||||
CAP_SYS_TTY_CONFIG = 26
|
||||
CAP_MKNOD = 27
|
||||
CAP_LEASE = 28
|
||||
CAP_AUDIT_WRITE = 29
|
||||
CAP_AUDIT_CONTROL = 30
|
||||
CAP_SETFCAP = 31
|
||||
CAP_MAC_OVERRIDE = 32
|
||||
CAP_MAC_ADMIN = 33
|
||||
CAP_SYSLOG = 34
|
||||
CAP_WAKE_ALARM = 35
|
||||
CAP_BLOCK_SUSPEND = 36
|
||||
CAP_AUDIT_READ = 37
|
||||
CAP_PERFMON = 38
|
||||
CAP_BPF = 39
|
||||
CAP_CHECKPOINT_RESTORE = 40
|
||||
|
||||
|
||||
CAPS_BYNAME: dict[str, int] = {}
|
||||
CAPS_BYVALUE: dict[int, str] = {}
|
||||
module = sys.modules[__name__]
|
||||
# Convenience dicts for human readable values
|
||||
# module attributes for backwards compat/convenience
|
||||
@@ -114,6 +164,11 @@ ffi = cffi.FFI()
|
||||
ffi.cdef(CDEF)
|
||||
|
||||
|
||||
crt: Any
|
||||
_prctl: Any
|
||||
_capget: Any
|
||||
_capset: Any
|
||||
|
||||
if platform.system() == 'Linux':
|
||||
# mock.patching crt.* directly seems to upset cffi. Use an
|
||||
# indirection point here for easier testing.
|
||||
@@ -127,7 +182,7 @@ else:
|
||||
_capset = None
|
||||
|
||||
|
||||
def set_keepcaps(enable):
|
||||
def set_keepcaps(enable: bool) -> None:
|
||||
"""Set/unset thread's "keep capabilities" flag - see prctl(2)"""
|
||||
ret = _prctl(crt.PR_SET_KEEPCAPS, ffi.cast('unsigned long', bool(enable)))
|
||||
if ret != 0:
|
||||
@@ -135,7 +190,11 @@ def set_keepcaps(enable):
|
||||
raise OSError(errno, os.strerror(errno))
|
||||
|
||||
|
||||
def drop_all_caps_except(effective, permitted, inheritable):
|
||||
def drop_all_caps_except(
|
||||
effective: Iterable[int],
|
||||
permitted: Iterable[int],
|
||||
inheritable: Iterable[int],
|
||||
) -> None:
|
||||
"""Set (effective, permitted, inheritable) to provided list of caps"""
|
||||
eff = _caps_to_mask(effective)
|
||||
prm = _caps_to_mask(permitted)
|
||||
@@ -159,12 +218,12 @@ def drop_all_caps_except(effective, permitted, inheritable):
|
||||
raise OSError(errno, os.strerror(errno))
|
||||
|
||||
|
||||
def _mask_to_caps(mask):
|
||||
def _mask_to_caps(mask: int) -> list[int]:
|
||||
"""Convert bitmask to list of set bit offsets"""
|
||||
return [i for i in range(64) if (1 << i) & mask]
|
||||
|
||||
|
||||
def _caps_to_mask(caps):
|
||||
def _caps_to_mask(caps: Iterable[int]) -> int:
|
||||
"""Convert list of bit offsets to bitmask"""
|
||||
mask = 0
|
||||
for cap in caps:
|
||||
@@ -172,7 +231,7 @@ def _caps_to_mask(caps):
|
||||
return mask
|
||||
|
||||
|
||||
def get_caps():
|
||||
def get_caps() -> tuple[list[int], list[int], list[int]]:
|
||||
"""Return (effective, permitted, inheritable) as lists of caps"""
|
||||
header = ffi.new(
|
||||
'cap_user_header_t',
|
||||
|
||||
@@ -20,12 +20,18 @@ python datatypes. Msgpack 'raw' is assumed to be a valid utf8 string
|
||||
converted to tuples during serialization/deserialization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
import datetime
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
import msgpack
|
||||
|
||||
@@ -52,16 +58,16 @@ class PrivsepTimeout(Exception):
|
||||
|
||||
|
||||
class Serializer:
|
||||
def __init__(self, writesock):
|
||||
def __init__(self, writesock: socket.socket) -> None:
|
||||
self.writesock = writesock
|
||||
|
||||
def send(self, msg):
|
||||
def send(self, msg: Any) -> None:
|
||||
buf = msgpack.packb(
|
||||
msg, use_bin_type=True, unicode_errors='surrogateescape'
|
||||
)
|
||||
self.writesock.sendall(buf)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
# Hilarious. `socket._socketobject.close()` doesn't actually
|
||||
# call `self._sock.close()`. Oh well, we really wanted a half
|
||||
# close anyway.
|
||||
@@ -69,7 +75,7 @@ class Serializer:
|
||||
|
||||
|
||||
class Deserializer:
|
||||
def __init__(self, readsock):
|
||||
def __init__(self, readsock: socket.socket) -> None:
|
||||
self.readsock = readsock
|
||||
self.unpacker = msgpack.Unpacker(
|
||||
use_list=False,
|
||||
@@ -81,10 +87,10 @@ class Deserializer:
|
||||
max_buffer_size=100 * 1024 * 1024,
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
def __next__(self) -> Any:
|
||||
while True:
|
||||
try:
|
||||
return next(self.unpacker)
|
||||
@@ -101,23 +107,25 @@ class Deserializer:
|
||||
class Future:
|
||||
"""A very simple object to track the return of a function call"""
|
||||
|
||||
def __init__(self, lock, timeout=None):
|
||||
def __init__(
|
||||
self, lock: threading.Lock, timeout: float | None = None
|
||||
) -> None:
|
||||
self.condvar = threading.Condition(lock)
|
||||
self.error = None
|
||||
self.data = None
|
||||
self.error: BaseException | None = None
|
||||
self.data: Any = None
|
||||
self.timeout = timeout
|
||||
|
||||
def set_result(self, data):
|
||||
def set_result(self, data: Any) -> None:
|
||||
"""Must already be holding lock used in constructor"""
|
||||
self.data = data
|
||||
self.condvar.notify()
|
||||
|
||||
def set_exception(self, exc):
|
||||
def set_exception(self, exc: BaseException) -> None:
|
||||
"""Must already be holding lock used in constructor"""
|
||||
self.error = exc
|
||||
self.condvar.notify()
|
||||
|
||||
def result(self):
|
||||
def result(self) -> Any:
|
||||
"""Must already be holding lock used in constructor"""
|
||||
before = datetime.datetime.now()
|
||||
if not self.condvar.wait(timeout=self.timeout):
|
||||
@@ -139,7 +147,7 @@ class Future:
|
||||
|
||||
|
||||
class ClientChannel:
|
||||
def __init__(self, sock):
|
||||
def __init__(self, sock: socket.socket) -> None:
|
||||
self.running = False
|
||||
self.writer = Serializer(sock)
|
||||
self.lock = threading.Lock()
|
||||
@@ -149,11 +157,11 @@ class ClientChannel:
|
||||
args=(Deserializer(sock),),
|
||||
)
|
||||
self.reader_thread.daemon = True
|
||||
self.outstanding_msgs = {}
|
||||
self.outstanding_msgs: dict[str, Future] = {}
|
||||
|
||||
self.reader_thread.start()
|
||||
|
||||
def _reader_main(self, reader):
|
||||
def _reader_main(self, reader: Deserializer) -> None:
|
||||
"""This thread owns and demuxes the read channel"""
|
||||
with self.lock:
|
||||
self.running = True
|
||||
@@ -183,11 +191,11 @@ class ClientChannel:
|
||||
mbox.set_exception(exc)
|
||||
self.running = False
|
||||
|
||||
def out_of_band(self, msg):
|
||||
def out_of_band(self, msg: Any) -> None:
|
||||
"""Received OOB message. Subclasses might want to override this."""
|
||||
pass
|
||||
|
||||
def send_recv(self, msg, timeout=None):
|
||||
def send_recv(self, msg: Any, timeout: float | None = None) -> Any:
|
||||
myid = uuidutils.generate_uuid()
|
||||
while myid in self.outstanding_msgs:
|
||||
LOG.warning("myid shoudn't be in outstanding_msgs.")
|
||||
@@ -208,7 +216,7 @@ class ClientChannel:
|
||||
|
||||
return reply
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
with self.lock:
|
||||
self.writer.close()
|
||||
|
||||
@@ -218,19 +226,19 @@ class ClientChannel:
|
||||
class ServerChannel:
|
||||
"""Server-side twin to ClientChannel"""
|
||||
|
||||
def __init__(self, sock):
|
||||
def __init__(self, sock: socket.socket) -> None:
|
||||
self.rlock = threading.Lock()
|
||||
self.reader_iter = iter(Deserializer(sock))
|
||||
self.reader_iter: Iterator[Any] = iter(Deserializer(sock))
|
||||
self.wlock = threading.Lock()
|
||||
self.writer = Serializer(sock)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
def __next__(self) -> Any:
|
||||
with self.rlock:
|
||||
return next(self.reader_iter)
|
||||
|
||||
def send(self, msg):
|
||||
def send(self, msg: Any) -> None:
|
||||
with self.wlock:
|
||||
self.writer.send(msg)
|
||||
|
||||
@@ -43,6 +43,10 @@ The privsep daemon exits when the communication channel is closed,
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from concurrent import futures
|
||||
import enum
|
||||
import errno
|
||||
@@ -55,6 +59,8 @@ import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import debtcollector
|
||||
import eventlet
|
||||
@@ -68,6 +74,9 @@ from oslo_privsep._i18n import _
|
||||
from oslo_privsep import capabilities
|
||||
from oslo_privsep import comm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from oslo_privsep import priv_context
|
||||
|
||||
if platform.system() == 'Linux':
|
||||
import fcntl
|
||||
import grp
|
||||
@@ -76,7 +85,7 @@ if platform.system() == 'Linux':
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
EVENTLET_MODULES = (
|
||||
EVENTLET_MODULES: tuple[str, ...] = (
|
||||
'os',
|
||||
'select',
|
||||
'socket',
|
||||
@@ -86,10 +95,10 @@ EVENTLET_MODULES = (
|
||||
'builtins',
|
||||
'subprocess',
|
||||
)
|
||||
EVENTLET_LIBRARIES = []
|
||||
EVENTLET_LIBRARIES: list[tuple[str, Any]] = []
|
||||
|
||||
|
||||
def _null():
|
||||
def _null() -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
@@ -133,18 +142,18 @@ class ProtocolError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def set_cloexec(fd):
|
||||
def set_cloexec(fd: int | socket.socket) -> None:
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
|
||||
if (flags & fcntl.FD_CLOEXEC) == 0:
|
||||
flags |= fcntl.FD_CLOEXEC
|
||||
fcntl.fcntl(fd, fcntl.F_SETFD, flags)
|
||||
|
||||
|
||||
def setuid(user_id_or_name):
|
||||
def setuid(user_id_or_name: str | int) -> None:
|
||||
try:
|
||||
new_uid = int(user_id_or_name)
|
||||
except (TypeError, ValueError):
|
||||
new_uid = pwd.getpwnam(user_id_or_name).pw_uid
|
||||
new_uid = pwd.getpwnam(str(user_id_or_name)).pw_uid
|
||||
if new_uid != 0:
|
||||
try:
|
||||
os.setuid(new_uid)
|
||||
@@ -154,11 +163,11 @@ def setuid(user_id_or_name):
|
||||
raise FailedToDropPrivileges(msg)
|
||||
|
||||
|
||||
def setgid(group_id_or_name):
|
||||
def setgid(group_id_or_name: str | int) -> None:
|
||||
try:
|
||||
new_gid = int(group_id_or_name)
|
||||
except (TypeError, ValueError):
|
||||
new_gid = grp.getgrnam(group_id_or_name).gr_gid
|
||||
new_gid = grp.getgrnam(str(group_id_or_name)).gr_gid
|
||||
if new_gid != 0:
|
||||
try:
|
||||
os.setgid(new_gid)
|
||||
@@ -169,12 +178,16 @@ def setgid(group_id_or_name):
|
||||
|
||||
|
||||
class PrivsepLogHandler(pylogging.Handler):
|
||||
def __init__(self, channel, processName=None):
|
||||
def __init__(
|
||||
self,
|
||||
channel: comm.ServerChannel,
|
||||
processName: str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.channel = channel
|
||||
self.processName = processName
|
||||
|
||||
def emit(self, record):
|
||||
def emit(self, record: pylogging.LogRecord) -> None:
|
||||
# Vaguely based on pylogging.handlers.SocketHandler.makePickle
|
||||
|
||||
if self.processName:
|
||||
@@ -198,13 +211,15 @@ class PrivsepLogHandler(pylogging.Handler):
|
||||
class _ClientChannel(comm.ClientChannel):
|
||||
"""Our protocol, layered on the basic primitives in comm.ClientChannel"""
|
||||
|
||||
def __init__(self, sock, context):
|
||||
def __init__(
|
||||
self, sock: socket.socket, context: priv_context.PrivContext
|
||||
) -> None:
|
||||
self.log = logging.getLogger(context.conf.logger_name)
|
||||
self.log_traceback = context.conf.log_daemon_traceback
|
||||
self.log_traceback: bool = context.conf.log_daemon_traceback
|
||||
super().__init__(sock)
|
||||
self.exchange_ping()
|
||||
|
||||
def exchange_ping(self):
|
||||
def exchange_ping(self) -> None:
|
||||
try:
|
||||
# exchange "ready" messages
|
||||
reply = self.send_recv((comm.Message.PING.value,))
|
||||
@@ -219,7 +234,13 @@ class _ClientChannel(comm.ClientChannel):
|
||||
self.log.critical(msg)
|
||||
raise FailedToDropPrivileges(msg)
|
||||
|
||||
def remote_call(self, name, args, kwargs, timeout):
|
||||
def remote_call(
|
||||
self,
|
||||
name: str,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
result = self.send_recv(
|
||||
(comm.Message.CALL.value, name, args, kwargs), timeout
|
||||
)
|
||||
@@ -242,7 +263,7 @@ class _ClientChannel(comm.ClientChannel):
|
||||
else:
|
||||
raise ProtocolError(_('Unexpected response: %r') % result)
|
||||
|
||||
def out_of_band(self, msg):
|
||||
def out_of_band(self, msg: Any) -> None:
|
||||
if msg[0] == comm.Message.LOG:
|
||||
# (LOG, LogRecord __dict__)
|
||||
message = {
|
||||
@@ -258,7 +279,7 @@ class _ClientChannel(comm.ClientChannel):
|
||||
)
|
||||
|
||||
|
||||
def fdopen(fd, *args, **kwargs):
|
||||
def fdopen(fd: int, *args: Any, **kwargs: Any) -> Any:
|
||||
# NOTE(gus): We can't just use os.fdopen() here and allow the
|
||||
# regular (optional) monkey_patching to do its thing. Turns out
|
||||
# that regular file objects (as returned by os.fdopen) on python2
|
||||
@@ -271,13 +292,13 @@ def fdopen(fd, *args, **kwargs):
|
||||
return open(fd, *args, **kwargs)
|
||||
|
||||
|
||||
def _fd_logger(level=logging.WARN):
|
||||
def _fd_logger(level: int = logging.WARN) -> Any:
|
||||
"""Helper that returns a file object that is asynchronously logged"""
|
||||
read_fd, write_fd = os.pipe()
|
||||
read_end = fdopen(read_fd, 'r', 1)
|
||||
write_end = fdopen(write_fd, 'w', 1)
|
||||
|
||||
def logger(f):
|
||||
def logger(f: Any) -> None:
|
||||
for line in f:
|
||||
LOG.log(level, 'privsep log: %s', line.rstrip())
|
||||
|
||||
@@ -288,7 +309,10 @@ def _fd_logger(level=logging.WARN):
|
||||
return write_end
|
||||
|
||||
|
||||
def replace_logging(handler, log_root=None):
|
||||
def replace_logging(
|
||||
handler: pylogging.Handler,
|
||||
log_root: pylogging.Logger | None = None,
|
||||
) -> None:
|
||||
if log_root is None:
|
||||
log_root = logging.getLogger(None).logger # root logger
|
||||
for h in log_root.handlers:
|
||||
@@ -296,7 +320,7 @@ def replace_logging(handler, log_root=None):
|
||||
log_root.addHandler(handler)
|
||||
|
||||
|
||||
def un_monkey_patch():
|
||||
def un_monkey_patch() -> None:
|
||||
for eventlet_mod_name, func_modules in EVENTLET_LIBRARIES:
|
||||
if not eventlet.patcher.is_monkey_patched(eventlet_mod_name):
|
||||
continue
|
||||
@@ -312,7 +336,7 @@ def un_monkey_patch():
|
||||
|
||||
|
||||
class ForkingClientChannel(_ClientChannel):
|
||||
def __init__(self, context):
|
||||
def __init__(self, context: priv_context.PrivContext) -> None:
|
||||
"""Start privsep daemon using fork()
|
||||
|
||||
Assumes we already have required privileges.
|
||||
@@ -353,7 +377,7 @@ class ForkingClientChannel(_ClientChannel):
|
||||
|
||||
|
||||
class RootwrapClientChannel(_ClientChannel):
|
||||
def __init__(self, context):
|
||||
def __init__(self, context: priv_context.PrivContext) -> None:
|
||||
"""Start privsep daemon using exec()
|
||||
|
||||
Uses sudo/rootwrap to gain privileges.
|
||||
@@ -409,18 +433,22 @@ class RootwrapClientChannel(_ClientChannel):
|
||||
class Daemon:
|
||||
"""NB: This doesn't fork() - do that yourself before calling run()"""
|
||||
|
||||
def __init__(self, channel, context):
|
||||
def __init__(
|
||||
self,
|
||||
channel: comm.ServerChannel,
|
||||
context: priv_context.PrivContext,
|
||||
) -> None:
|
||||
self.channel = channel
|
||||
self.context = context
|
||||
self.user = context.conf.user
|
||||
self.group = context.conf.group
|
||||
self.caps = set(context.conf.capabilities)
|
||||
self.user: str | int | None = context.conf.user
|
||||
self.group: str | int | None = context.conf.group
|
||||
self.caps: set[int] = set(context.conf.capabilities)
|
||||
self.thread_pool = futures.ThreadPoolExecutor(
|
||||
context.conf.thread_pool_size
|
||||
)
|
||||
self.communication_error = None
|
||||
self.communication_error: BaseException | None = None
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
"""Run request loop. Sets up environment, then calls loop()"""
|
||||
os.chdir("/")
|
||||
os.umask(0)
|
||||
@@ -429,13 +457,13 @@ class Daemon:
|
||||
|
||||
self.loop()
|
||||
|
||||
def _close_stdio(self):
|
||||
def _close_stdio(self) -> None:
|
||||
with open(os.devnull, 'w+') as devnull:
|
||||
os.dup2(devnull.fileno(), StdioFd.STDIN)
|
||||
os.dup2(devnull.fileno(), StdioFd.STDOUT)
|
||||
# stderr is left untouched
|
||||
|
||||
def _drop_privs(self):
|
||||
def _drop_privs(self) -> None:
|
||||
try:
|
||||
# Keep current capabilities across setuid away from root.
|
||||
capabilities.set_keepcaps(True)
|
||||
@@ -462,7 +490,7 @@ class Daemon:
|
||||
|
||||
capabilities.drop_all_caps_except(self.caps, self.caps, [])
|
||||
|
||||
def fmt_caps(capset):
|
||||
def fmt_caps(capset: Iterable[int]) -> str:
|
||||
if not capset:
|
||||
return 'none'
|
||||
fc = [capabilities.CAPS_BYVALUE.get(c, str(c)) for c in capset]
|
||||
@@ -480,7 +508,9 @@ class Daemon:
|
||||
},
|
||||
)
|
||||
|
||||
def _process_cmd(self, msgid, cmd, *args):
|
||||
def _process_cmd(
|
||||
self, msgid: str, cmd: comm.Message, *args: Any
|
||||
) -> tuple[Any, ...]:
|
||||
"""Executes the requested command in an execution thread.
|
||||
|
||||
This executes a call within a thread executor and returns the results
|
||||
@@ -523,7 +553,9 @@ class Daemon:
|
||||
traceback.format_exc(),
|
||||
)
|
||||
|
||||
def _create_done_callback(self, msgid):
|
||||
def _create_done_callback(
|
||||
self, msgid: str
|
||||
) -> Callable[[futures.Future[tuple[Any, ...]]], None]:
|
||||
"""Creates a future callback to receive command execution results.
|
||||
|
||||
:param msgid: The message identifier.
|
||||
@@ -531,7 +563,7 @@ class Daemon:
|
||||
"""
|
||||
channel = self.channel
|
||||
|
||||
def _call_back(result):
|
||||
def _call_back(result: futures.Future[tuple[Any, ...]]) -> None:
|
||||
"""Future execution callback.
|
||||
|
||||
:param result: The `future` execution and its results.
|
||||
@@ -544,7 +576,7 @@ class Daemon:
|
||||
)
|
||||
channel.send((msgid, reply))
|
||||
except OSError:
|
||||
self.communication_error = sys.exc_info()
|
||||
self.communication_error = sys.exc_info()[1]
|
||||
except Exception as e:
|
||||
LOG.debug(
|
||||
'privsep: Exception during request[%(msgid)s]: %(err)s',
|
||||
@@ -566,7 +598,7 @@ class Daemon:
|
||||
|
||||
return _call_back
|
||||
|
||||
def loop(self):
|
||||
def loop(self) -> None:
|
||||
"""Main body of daemon request loop"""
|
||||
LOG.info('privsep daemon running as pid %s', os.getpid())
|
||||
|
||||
@@ -577,7 +609,7 @@ class Daemon:
|
||||
for msgid, msg in self.channel:
|
||||
error = self.communication_error
|
||||
if error:
|
||||
if error.errno == errno.EPIPE:
|
||||
if getattr(error, 'errno', None) == errno.EPIPE:
|
||||
# Write stream closed, exit loop
|
||||
break
|
||||
raise error
|
||||
@@ -589,7 +621,7 @@ class Daemon:
|
||||
LOG.debug('Socket closed, shutting down privsep daemon')
|
||||
|
||||
|
||||
def helper_main():
|
||||
def helper_main() -> None:
|
||||
"""Start privileged process, serving requests over a Unix socket."""
|
||||
|
||||
cfg.CONF.register_cli_opts(
|
||||
@@ -606,9 +638,9 @@ def helper_main():
|
||||
logging.setup(cfg.CONF, 'privsep', fix_eventlet=False)
|
||||
|
||||
context = importutils.import_class(cfg.CONF.privsep_context)
|
||||
from oslo_privsep import priv_context # Avoid circular import
|
||||
from oslo_privsep import priv_context as priv_context_mod # Avoid circular
|
||||
|
||||
if not isinstance(context, priv_context.PrivContext):
|
||||
if not isinstance(context, priv_context_mod.PrivContext):
|
||||
LOG.fatal(
|
||||
'--privsep_context must be the (python) name of a '
|
||||
'PrivContext object'
|
||||
|
||||
@@ -12,6 +12,10 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
import copy
|
||||
import enum
|
||||
import functools
|
||||
@@ -19,6 +23,7 @@ import logging
|
||||
import multiprocessing
|
||||
import shlex
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_config import types
|
||||
@@ -32,12 +37,12 @@ from oslo_privsep import daemon
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def CapNameOrInt(value):
|
||||
value = str(value).strip()
|
||||
def CapNameOrInt(value: str | int) -> int:
|
||||
value_str = str(value).strip()
|
||||
try:
|
||||
return capabilities.CAPS_BYNAME[value]
|
||||
return capabilities.CAPS_BYNAME[value_str]
|
||||
except KeyError:
|
||||
return int(value)
|
||||
return int(value_str)
|
||||
|
||||
|
||||
OPTS = [
|
||||
@@ -96,7 +101,7 @@ _ENTRYPOINT_ATTR = 'privsep_entrypoint'
|
||||
_HELPER_COMMAND_PREFIX = ['sudo']
|
||||
|
||||
|
||||
def _list_opts():
|
||||
def _list_opts() -> list[tuple[cfg.OptGroup, list[cfg.Opt]]]:
|
||||
"""Returns a list of oslo.config options available in the library.
|
||||
|
||||
The returned list includes all oslo.config options which may be registered
|
||||
@@ -130,7 +135,7 @@ class Method(enum.Enum):
|
||||
ROOTWRAP = 2
|
||||
|
||||
|
||||
def init(root_helper=None):
|
||||
def init(root_helper: list[str] | None = None) -> None:
|
||||
"""Initialise oslo.privsep library.
|
||||
|
||||
This function should be called at the top of main(), after the
|
||||
@@ -151,13 +156,13 @@ def init(root_helper=None):
|
||||
class PrivContext:
|
||||
def __init__(
|
||||
self,
|
||||
prefix,
|
||||
cfg_section='privsep',
|
||||
pypath=None,
|
||||
capabilities=None,
|
||||
logger_name='oslo_privsep.daemon',
|
||||
timeout=None,
|
||||
):
|
||||
prefix: str,
|
||||
cfg_section: str = 'privsep',
|
||||
pypath: str | None = None,
|
||||
capabilities: Iterable[int] | None = None,
|
||||
logger_name: str = 'oslo_privsep.daemon',
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
# Note that capabilities=[] means retaining no capabilities
|
||||
# and leaves even uid=0 with no powers except being able to
|
||||
# read/write to the filesystem as uid=0. This might be what
|
||||
@@ -173,12 +178,12 @@ class PrivContext:
|
||||
self.cfg_section = cfg_section
|
||||
|
||||
self.client_mode = True
|
||||
self.channel = None
|
||||
self.channel: daemon._ClientChannel | None = None
|
||||
self.start_lock = threading.Lock()
|
||||
|
||||
cfg.CONF.register_opts(OPTS, group=cfg_section)
|
||||
cfg.CONF.set_default(
|
||||
'capabilities', group=cfg_section, default=capabilities
|
||||
'capabilities', group=cfg_section, default=list(capabilities)
|
||||
)
|
||||
cfg.CONF.set_default(
|
||||
'logger_name', group=cfg_section, default=logger_name
|
||||
@@ -186,16 +191,16 @@ class PrivContext:
|
||||
self.timeout = timeout
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
def conf(self) -> Any:
|
||||
"""Return the oslo.config section object as lazily as possible."""
|
||||
# Need to avoid looking this up before oslo_config has been
|
||||
# properly initialized.
|
||||
return cfg.CONF[self.cfg_section]
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'PrivContext(cfg_section={self.cfg_section})'
|
||||
|
||||
def helper_command(self, sockpath):
|
||||
def helper_command(self, sockpath: str) -> list[str]:
|
||||
# We need to be able to reconstruct the context object in the new
|
||||
# python process we'll get after rootwrap/sudo. This means we
|
||||
# need to construct the context object and store it somewhere
|
||||
@@ -243,19 +248,21 @@ class PrivContext:
|
||||
|
||||
return cmd
|
||||
|
||||
def set_client_mode(self, enabled):
|
||||
def set_client_mode(self, enabled: bool) -> None:
|
||||
self.client_mode = enabled
|
||||
|
||||
def entrypoint(self, func):
|
||||
def entrypoint(self, func: Callable[..., Any]) -> functools.partial[Any]:
|
||||
"""This is intended to be used as a decorator."""
|
||||
return self._entrypoint(func)
|
||||
|
||||
def entrypoint_with_timeout(self, timeout):
|
||||
def entrypoint_with_timeout(
|
||||
self, timeout: float
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""This is intended to be used as a decorator with timeout."""
|
||||
|
||||
def wrap(func):
|
||||
def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
def inner(*args: Any, **kwargs: Any) -> Any:
|
||||
f = self._entrypoint(func)
|
||||
return f(*args, _wrap_timeout=timeout, **kwargs)
|
||||
|
||||
@@ -264,7 +271,7 @@ class PrivContext:
|
||||
|
||||
return wrap
|
||||
|
||||
def _entrypoint(self, func):
|
||||
def _entrypoint(self, func: Callable[..., Any]) -> functools.partial[Any]:
|
||||
if not func.__module__.startswith(self.prefix):
|
||||
raise AssertionError(
|
||||
f'{self!r} entrypoints must be below "{self.prefix}"'
|
||||
@@ -284,10 +291,16 @@ class PrivContext:
|
||||
setattr(f, _ENTRYPOINT_ATTR, self)
|
||||
return f
|
||||
|
||||
def is_entrypoint(self, func):
|
||||
def is_entrypoint(self, func: Callable[..., Any]) -> bool:
|
||||
return getattr(func, _ENTRYPOINT_ATTR, None) is self
|
||||
|
||||
def _wrap(self, func, *args, _wrap_timeout=None, **kwargs):
|
||||
def _wrap(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
*args: Any,
|
||||
_wrap_timeout: float | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if self.client_mode:
|
||||
name = f'{func.__module__}.{func.__name__}'
|
||||
if self.channel is not None and not self.channel.running:
|
||||
@@ -295,17 +308,21 @@ class PrivContext:
|
||||
self.stop()
|
||||
if self.channel is None:
|
||||
self.start()
|
||||
if self.channel is None:
|
||||
# narrow type: this will always be non-None thank to the above
|
||||
raise RuntimeError('channel is not initialized')
|
||||
r_call_timeout = _wrap_timeout or self.timeout
|
||||
return self.channel.remote_call(name, args, kwargs, r_call_timeout)
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def start(self, method=Method.ROOTWRAP):
|
||||
def start(self, method: Method = Method.ROOTWRAP) -> None:
|
||||
with self.start_lock:
|
||||
if self.channel is not None:
|
||||
LOG.warning('privsep daemon already running')
|
||||
return
|
||||
|
||||
channel: daemon._ClientChannel
|
||||
if method is Method.ROOTWRAP:
|
||||
channel = daemon.RootwrapClientChannel(context=self)
|
||||
elif method is Method.FORK:
|
||||
@@ -315,7 +332,7 @@ class PrivContext:
|
||||
|
||||
self.channel = channel
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
if self.channel is not None:
|
||||
self.channel.close()
|
||||
self.channel = None
|
||||
|
||||
0
oslo_privsep/py.typed
Normal file
0
oslo_privsep/py.typed
Normal file
@@ -15,6 +15,7 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import fixtures
|
||||
from oslo_config import fixture as cfg_fixture
|
||||
@@ -25,11 +26,15 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnprivilegedPrivsepFixture(fixtures.Fixture):
|
||||
def __init__(self, context, config_override):
|
||||
def __init__(
|
||||
self,
|
||||
context: priv_context.PrivContext,
|
||||
config_override: dict[str, Any],
|
||||
) -> None:
|
||||
self.context = context
|
||||
self.config_override = config_override
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
self.conf = self.useFixture(cfg_fixture.Config()).conf
|
||||
|
||||
@@ -46,8 +46,8 @@ class TestSerialization(base.BaseTestCase):
|
||||
|
||||
sock = BufSock()
|
||||
|
||||
self.input = comm.Serializer(sock)
|
||||
self.output = iter(comm.Deserializer(sock))
|
||||
self.input = comm.Serializer(sock) # type: ignore[arg-type]
|
||||
self.output = iter(comm.Deserializer(sock)) # type: ignore[arg-type]
|
||||
|
||||
def send(self, data):
|
||||
self.input.send(data)
|
||||
|
||||
@@ -55,7 +55,7 @@ def get_fake_context(conf_attrs=None, **context_attrs):
|
||||
capabilities.CAP_NET_ADMIN,
|
||||
]
|
||||
context.conf.logger_name = 'oslo_privsep.daemon'
|
||||
vars(context).update(context_attrs)
|
||||
vars(context).update(context_attrs) # type: ignore[attr-defined]
|
||||
vars(context.conf).update(conf_attrs)
|
||||
return context
|
||||
|
||||
@@ -109,15 +109,16 @@ class LogTest(testctx.TestContextTestCase):
|
||||
self.assertIn('test@WARN', logger.output)
|
||||
|
||||
def test_record_data(self):
|
||||
logs = []
|
||||
logs: list[pylogging.LogRecord] = []
|
||||
# fixtures.FakeLogger accepts only a formatter class/function, not an
|
||||
# instance :(
|
||||
fmt = functools.partial(LogRecorder, logs)
|
||||
|
||||
self.useFixture(
|
||||
fixtures.FakeLogger(
|
||||
level=logging.INFO,
|
||||
format='dummy',
|
||||
# fixtures.FakeLogger accepts only a formatter
|
||||
# class/function, not an instance :(
|
||||
formatter=functools.partial(LogRecorder, logs),
|
||||
formatter=fmt, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -142,15 +143,16 @@ class LogTest(testctx.TestContextTestCase):
|
||||
self.assertEqual('logme', record.funcName)
|
||||
|
||||
def test_format_record(self):
|
||||
logs = []
|
||||
logs: list[pylogging.LogRecord] = []
|
||||
# fixtures.FakeLogger accepts only a formatter class/function, not an
|
||||
# instance :(
|
||||
fmt = functools.partial(LogRecorder, logs)
|
||||
|
||||
self.useFixture(
|
||||
fixtures.FakeLogger(
|
||||
level=logging.INFO,
|
||||
format='dummy',
|
||||
# fixtures.FakeLogger accepts only a formatter
|
||||
# class/function, not an instance :(
|
||||
formatter=functools.partial(LogRecorder, logs),
|
||||
formatter=fmt, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -181,14 +183,16 @@ class LogTestDaemonTraceback(testctx.TestContextTestCase):
|
||||
self.privsep_conf.set_override(
|
||||
'log_daemon_traceback', True, group='privsep'
|
||||
)
|
||||
logs = []
|
||||
logs: list[pylogging.LogRecord] = []
|
||||
# fixtures.FakeLogger accepts only a formatter class/function, not an
|
||||
# instance :(
|
||||
fmt = functools.partial(LogRecorder, logs)
|
||||
|
||||
self.useFixture(
|
||||
fixtures.FakeLogger(
|
||||
level=logging.INFO,
|
||||
format='dummy',
|
||||
# fixtures.FakeLogger accepts only a formatter
|
||||
# class/function, not an instance :(
|
||||
formatter=functools.partial(LogRecorder, logs),
|
||||
formatter=fmt, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -303,7 +307,7 @@ class ClientChannelTestCase(base.BaseTestCase):
|
||||
self.client_channel.out_of_band([comm.Message.PING])
|
||||
mock_warning.assert_called_once()
|
||||
|
||||
@mock.patch.object(daemon.logging, 'getLogger')
|
||||
@mock.patch.object(daemon.logging, 'getLogger') # type: ignore[attr-defined]
|
||||
@mock.patch.object(pylogging, 'makeLogRecord')
|
||||
def test_out_of_band_log_message_context_logger(
|
||||
self, make_log_mock, get_logger_mock
|
||||
|
||||
@@ -142,7 +142,7 @@ class PrivContextTest(testctx.TestContextTestCase):
|
||||
|
||||
def test_start_acquires_lock(self):
|
||||
context = priv_context.PrivContext('test', capabilities=[])
|
||||
context.channel = "something not None"
|
||||
context.channel = "something not None" # type: ignore[assignment]
|
||||
context.start_lock = mock.Mock()
|
||||
context.start_lock.__enter__ = mock.Mock()
|
||||
context.start_lock.__exit__ = mock.Mock()
|
||||
|
||||
@@ -24,6 +24,7 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dynamic = ["version", "dependencies"]
|
||||
|
||||
@@ -49,3 +50,17 @@ docstring-code-format = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E5", "E7", "E9", "F", "G", "LOG", "S", "UP"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
show_column_numbers = true
|
||||
show_error_context = true
|
||||
strict = true
|
||||
disable_error_code = ["import-untyped"]
|
||||
exclude = '(?x)(doc | releasenotes)'
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["oslo_privsep.tests.*", "oslo_privsep.functional.*"]
|
||||
disallow_untyped_calls = false
|
||||
disallow_untyped_defs = false
|
||||
disallow_subclassing_any = false
|
||||
|
||||
17
tox.ini
17
tox.ini
@@ -5,8 +5,8 @@ envlist = py3,pypy,pep8
|
||||
[testenv]
|
||||
deps =
|
||||
-c{env:TOX_CONSTRAINTS_FILE:https://releases.openstack.org/constraints/upper/master}
|
||||
-r{toxinidir}/test-requirements.txt
|
||||
-r{toxinidir}/requirements.txt
|
||||
-r{toxinidir}/test-requirements.txt
|
||||
commands = stestr run --slowest {posargs}
|
||||
|
||||
[testenv:functional]
|
||||
@@ -51,11 +51,23 @@ commands =
|
||||
sphinx-build -a -E -W -d releasenotes/build/doctrees --keep-going -b html releasenotes/source releasenotes/build/html
|
||||
|
||||
[testenv:pep8]
|
||||
skip_install = true
|
||||
description =
|
||||
Run style checks.
|
||||
deps =
|
||||
pre-commit>=2.6.0 # MIT
|
||||
{[testenv:mypy]deps}
|
||||
commands =
|
||||
pre-commit run -a
|
||||
{[testenv:mypy]commands}
|
||||
|
||||
[testenv:mypy]
|
||||
description =
|
||||
Run type checks.
|
||||
deps =
|
||||
{[testenv]deps}
|
||||
mypy
|
||||
commands =
|
||||
mypy --cache-dir="{envdir}/mypy_cache" {posargs:oslo_privsep}
|
||||
|
||||
[flake8]
|
||||
# We only enable the hacking (H) checks
|
||||
@@ -67,3 +79,4 @@ exclude = .venv,.git,.tox,dist,doc,*lib/python*,*egg,build
|
||||
[hacking]
|
||||
import_exceptions =
|
||||
oslo_privsep._i18n
|
||||
typing
|
||||
|
||||
Reference in New Issue
Block a user