Add typing
Signed-off-by: Stephen Finucane <stephenfin@redhat.com> Change-Id: Iaaf09440270650efc2f522837091b4ce9f791289
This commit is contained in:
@@ -17,15 +17,19 @@ import os
|
||||
|
||||
import debtcollector
|
||||
|
||||
__all__ = [
|
||||
'subprocess',
|
||||
]
|
||||
|
||||
try:
|
||||
import eventlet.patcher
|
||||
except ImportError:
|
||||
_patched_socket = False
|
||||
else:
|
||||
# In tests patching happens later, so we'll rely on environment variable
|
||||
_patched_socket = eventlet.patcher.is_monkey_patched(
|
||||
'socket'
|
||||
) or os.environ.get('TEST_EVENTLET', False)
|
||||
_patched_socket = eventlet.patcher.is_monkey_patched('socket') or bool(
|
||||
os.environ.get('TEST_EVENTLET', False)
|
||||
)
|
||||
|
||||
if not _patched_socket:
|
||||
import subprocess
|
||||
@@ -33,4 +37,4 @@ else:
|
||||
debtcollector.deprecate(
|
||||
"Eventlet support is deprecated and will be soon removed."
|
||||
)
|
||||
from eventlet.green import subprocess # noqa
|
||||
from eventlet.green import subprocess # type: ignore
|
||||
|
||||
@@ -13,10 +13,13 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import logging
|
||||
from multiprocessing import managers
|
||||
import subprocess as stdlib_subprocess
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
import weakref
|
||||
|
||||
import oslo_rootwrap
|
||||
@@ -30,8 +33,9 @@ if oslo_rootwrap._patched_socket:
|
||||
# https://bitbucket.org/eventlet/eventlet/pull-request/41
|
||||
# This check happens here instead of jsonrpc to avoid importing eventlet
|
||||
# from daemon code that is run with root privileges.
|
||||
jsonrpc.JsonConnection.recvall = jsonrpc.JsonConnection._recvall_slow
|
||||
|
||||
jsonrpc.JsonConnection.recvall = ( # type: ignore[method-assign]
|
||||
jsonrpc.JsonConnection._recvall_slow
|
||||
)
|
||||
|
||||
ClientManager = daemon.get_manager_class()
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -39,7 +43,17 @@ SHUTDOWN_RETRIES = 3
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, rootwrap_daemon_cmd):
|
||||
_start_command: Sequence[str]
|
||||
_initialized: bool
|
||||
_need_restart: bool
|
||||
_mutex: threading.Lock
|
||||
_manager: managers.BaseManager | None
|
||||
_proxy: Any # RootwrapClass proxy
|
||||
_process: stdlib_subprocess.Popen[bytes] | None
|
||||
_finalize: Callable[[], None] | None
|
||||
_exec_sem: threading.Lock
|
||||
|
||||
def __init__(self, rootwrap_daemon_cmd: Sequence[str]) -> None:
|
||||
self._start_command = rootwrap_daemon_cmd
|
||||
self._initialized = False
|
||||
self._need_restart = False
|
||||
@@ -53,7 +67,7 @@ class Client:
|
||||
# needed with the threading module.
|
||||
self._exec_sem = threading.Lock()
|
||||
|
||||
def _initialize(self):
|
||||
def _initialize(self) -> None:
|
||||
if self._process is not None and self._process.poll() is not None:
|
||||
LOG.warning(
|
||||
"Leaving behind already spawned process with pid %d, "
|
||||
@@ -73,19 +87,21 @@ class Client:
|
||||
)
|
||||
|
||||
self._process = process_obj
|
||||
assert process_obj.stdout is not None # narrow type
|
||||
assert process_obj.stderr is not None # narrow type
|
||||
socket_path = process_obj.stdout.readline()[:-1].decode('utf-8')
|
||||
authkey = process_obj.stdout.read(32)
|
||||
if process_obj.poll() is not None:
|
||||
stderr = process_obj.stderr.read()
|
||||
# NOTE(yorik-sar): don't expose stdout here
|
||||
raise Exception(
|
||||
f"Failed to spawn rootwrap process.\nstderr:\n{stderr}"
|
||||
f"Failed to spawn rootwrap process.\nstderr:\n{stderr!r}"
|
||||
)
|
||||
LOG.info(
|
||||
"Spawned new rootwrap daemon process with pid=%d", process_obj.pid
|
||||
)
|
||||
|
||||
def wait_process():
|
||||
def wait_process() -> None:
|
||||
return_code = process_obj.wait()
|
||||
LOG.info(
|
||||
"Rootwrap daemon process exit with status: %d", return_code
|
||||
@@ -96,14 +112,18 @@ class Client:
|
||||
reap_process.start()
|
||||
self._manager = ClientManager(socket_path, authkey)
|
||||
self._manager.connect()
|
||||
self._proxy = self._manager.rootwrap()
|
||||
self._proxy = self._manager.rootwrap() # type: ignore[attr-defined]
|
||||
self._finalize = weakref.finalize(
|
||||
self, self._shutdown, self._process, self._manager
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
@staticmethod
|
||||
def _shutdown(process, manager, JsonClient=jsonrpc.JsonClient):
|
||||
def _shutdown(
|
||||
process: stdlib_subprocess.Popen[bytes],
|
||||
manager: managers.BaseManager,
|
||||
JsonClient: type[jsonrpc.JsonClient] = jsonrpc.JsonClient,
|
||||
) -> None:
|
||||
# Storing JsonClient in arguments because globals are set to None
|
||||
# before executing atexit routines in Python 2.x
|
||||
if process.poll() is None:
|
||||
@@ -112,7 +132,7 @@ class Client:
|
||||
)
|
||||
for _ in range(SHUTDOWN_RETRIES):
|
||||
try:
|
||||
manager.rootwrap().shutdown()
|
||||
manager.rootwrap().shutdown() # type: ignore[attr-defined]
|
||||
break
|
||||
except (EOFError, OSError):
|
||||
break # assume it is dead already
|
||||
@@ -122,19 +142,20 @@ class Client:
|
||||
# can't provide sane timeout on 2.x and we most likely don't have
|
||||
# permisions to do so
|
||||
# Invalidate manager's state so that proxy won't try to do decref
|
||||
manager._state.value = managers.State.SHUTDOWN
|
||||
manager._state.value = managers.State.SHUTDOWN # type: ignore[attr-defined]
|
||||
|
||||
def _ensure_initialized(self):
|
||||
def _ensure_initialized(self) -> None:
|
||||
with self._mutex:
|
||||
if not self._initialized:
|
||||
self._initialize()
|
||||
|
||||
def _restart(self, proxy):
|
||||
def _restart(self, proxy: Any) -> Any:
|
||||
with self._mutex:
|
||||
if not self._initialized:
|
||||
raise AssertionError("Client should be initialized.")
|
||||
# Verify if someone has already restarted this.
|
||||
if self._proxy is proxy:
|
||||
assert self._finalize is not None # narrow type
|
||||
self._finalize()
|
||||
self._manager = None
|
||||
self._proxy = None
|
||||
@@ -143,7 +164,9 @@ class Client:
|
||||
self._need_restart = False
|
||||
return self._proxy
|
||||
|
||||
def _run_one_command(self, proxy, cmd, stdin):
|
||||
def _run_one_command(
|
||||
self, proxy: Any, cmd: list[str], stdin: str | None
|
||||
) -> tuple[int, str, str]:
|
||||
"""Wrap proxy.run_one_command, setting _need_restart on an exception.
|
||||
|
||||
Usually it should be enough to drain stale data on socket
|
||||
@@ -151,14 +174,16 @@ class Client:
|
||||
"""
|
||||
try:
|
||||
_need_restart = True
|
||||
res = proxy.run_one_command(cmd, stdin)
|
||||
res: tuple[int, str, str] = proxy.run_one_command(cmd, stdin)
|
||||
_need_restart = False
|
||||
return res
|
||||
finally:
|
||||
if _need_restart:
|
||||
self._need_restart = True
|
||||
|
||||
def execute(self, cmd, stdin=None):
|
||||
def execute(
|
||||
self, cmd: list[str], stdin: str | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
with self._exec_sem:
|
||||
self._ensure_initialized()
|
||||
proxy = self._proxy
|
||||
|
||||
@@ -34,10 +34,14 @@ import configparser
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import NoReturn
|
||||
|
||||
from oslo_rootwrap import filters as filters_mod
|
||||
from oslo_rootwrap import subprocess
|
||||
from oslo_rootwrap import wrapper
|
||||
|
||||
resource: ModuleType | None
|
||||
try:
|
||||
# This isn't available on all platforms (e.g. Windows).
|
||||
import resource
|
||||
@@ -53,18 +57,20 @@ RC_NOEXECFOUND = 96
|
||||
SIGNAL_BASE = 128
|
||||
|
||||
|
||||
def _exit_error(execname, message, errorcode, log=True):
|
||||
def _exit_error(
|
||||
execname: str, message: str, errorcode: int, log: bool = True
|
||||
) -> NoReturn:
|
||||
print(f"{execname}: {message}", file=sys.stderr)
|
||||
if log:
|
||||
LOG.error(message)
|
||||
sys.exit(errorcode)
|
||||
|
||||
|
||||
def daemon():
|
||||
return main(run_daemon=True)
|
||||
def daemon() -> None:
|
||||
main(run_daemon=True)
|
||||
|
||||
|
||||
def main(run_daemon=False):
|
||||
def main(run_daemon: bool = False) -> None:
|
||||
# Split arguments, require at least a command
|
||||
execname = sys.argv.pop(0)
|
||||
if run_daemon:
|
||||
@@ -147,7 +153,12 @@ def main(run_daemon=False):
|
||||
run_one_command(execname, config, filters, sys.argv)
|
||||
|
||||
|
||||
def run_one_command(execname, config, filters, userargs):
|
||||
def run_one_command(
|
||||
execname: str,
|
||||
config: wrapper.RootwrapConfig,
|
||||
filters: list[filters_mod.CommandFilter],
|
||||
userargs: list[str],
|
||||
) -> NoReturn:
|
||||
# Execute command if it matches any of the loaded filters
|
||||
try:
|
||||
obj = wrapper.start_subprocess(
|
||||
@@ -165,6 +176,7 @@ def run_one_command(execname, config, filters, userargs):
|
||||
returncode = SIGNAL_BASE - returncode
|
||||
sys.exit(returncode)
|
||||
except wrapper.FilterMatchNotExecutable as exc:
|
||||
assert exc.match is not None # narrow type
|
||||
msg = (
|
||||
f"Executable not found: {exc.match.exec_path} "
|
||||
f"(filter match = {exc.match.name})"
|
||||
|
||||
@@ -24,8 +24,10 @@ import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from types import FrameType
|
||||
|
||||
from oslo_rootwrap import cmd
|
||||
from oslo_rootwrap import filters as filters_mod
|
||||
from oslo_rootwrap import jsonrpc
|
||||
from oslo_rootwrap import subprocess
|
||||
from oslo_rootwrap import wrapper
|
||||
@@ -35,17 +37,30 @@ LOG = logging.getLogger(__name__)
|
||||
# Since multiprocessing supports only pickle and xmlrpclib for serialization of
|
||||
# RPC requests and responses, we declare another 'jsonrpc' serializer
|
||||
|
||||
managers.listener_client['jsonrpc'] = jsonrpc.JsonListener, jsonrpc.JsonClient
|
||||
managers.listener_client['jsonrpc'] = ( # type: ignore[attr-defined]
|
||||
jsonrpc.JsonListener,
|
||||
jsonrpc.JsonClient,
|
||||
)
|
||||
|
||||
|
||||
class RootwrapClass:
|
||||
def __init__(self, config, filters):
|
||||
last_called: float
|
||||
daemon_timeout: int
|
||||
timeout: threading.Timer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: wrapper.RootwrapConfig,
|
||||
filters: list[filters_mod.CommandFilter],
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.filters = filters
|
||||
self.reset_timer()
|
||||
self.prepare_timer(config)
|
||||
|
||||
def run_one_command(self, userargs, stdin=None):
|
||||
def run_one_command(
|
||||
self, userargs: list[str], stdin: str | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
self.reset_timer()
|
||||
try:
|
||||
obj = wrapper.start_subprocess(
|
||||
@@ -69,26 +84,29 @@ class RootwrapClass:
|
||||
)
|
||||
return cmd.RC_UNAUTHORIZED, "", ""
|
||||
|
||||
stdin_bytes: bytes | None = None
|
||||
if stdin is not None:
|
||||
stdin = os.fsencode(stdin)
|
||||
out, err = obj.communicate(stdin)
|
||||
out = os.fsdecode(out)
|
||||
err = os.fsdecode(err)
|
||||
return obj.returncode, out, err
|
||||
stdin_bytes = os.fsencode(stdin)
|
||||
out, err = obj.communicate(stdin_bytes)
|
||||
out_str = os.fsdecode(out)
|
||||
err_str = os.fsdecode(err)
|
||||
return obj.returncode or 0, out_str, err_str
|
||||
|
||||
@classmethod
|
||||
def reset_timer(cls):
|
||||
def reset_timer(cls) -> None:
|
||||
cls.last_called = time.time()
|
||||
|
||||
@classmethod
|
||||
def cancel_timer(cls):
|
||||
def cancel_timer(cls) -> None:
|
||||
try:
|
||||
cls.timeout.cancel()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def prepare_timer(cls, config=None):
|
||||
def prepare_timer(
|
||||
cls, config: wrapper.RootwrapConfig | None = None
|
||||
) -> None:
|
||||
if config is not None:
|
||||
cls.daemon_timeout = config.daemon_timeout
|
||||
# Wait a bit longer to avoid rounding errors
|
||||
@@ -102,25 +120,31 @@ class RootwrapClass:
|
||||
cls.timeout.start()
|
||||
|
||||
@classmethod
|
||||
def handle_timeout(cls):
|
||||
def handle_timeout(cls) -> None:
|
||||
if cls.last_called < time.time() - cls.daemon_timeout:
|
||||
cls.shutdown()
|
||||
|
||||
cls.prepare_timer()
|
||||
|
||||
@staticmethod
|
||||
def shutdown():
|
||||
def shutdown() -> None:
|
||||
# Suicide to force break of the main thread
|
||||
os.kill(os.getpid(), signal.SIGINT)
|
||||
|
||||
|
||||
def get_manager_class(config=None, filters=None):
|
||||
def get_manager_class(
|
||||
config: wrapper.RootwrapConfig | None = None,
|
||||
filters: list[filters_mod.CommandFilter] | None = None,
|
||||
) -> type[managers.BaseManager]:
|
||||
class RootwrapManager(managers.BaseManager):
|
||||
def __init__(self, address=None, authkey=None):
|
||||
def __init__(
|
||||
self, address: str | None = None, authkey: bytes | None = None
|
||||
) -> None:
|
||||
# Force jsonrpc because neither pickle nor xmlrpclib is secure
|
||||
super().__init__(address, authkey, serializer='jsonrpc')
|
||||
|
||||
if config is not None:
|
||||
assert filters is not None # narrow type
|
||||
partial_class = functools.partial(RootwrapClass, config, filters)
|
||||
RootwrapManager.register('rootwrap', partial_class)
|
||||
else:
|
||||
@@ -129,7 +153,10 @@ def get_manager_class(config=None, filters=None):
|
||||
return RootwrapManager
|
||||
|
||||
|
||||
def daemon_start(config, filters):
|
||||
def daemon_start(
|
||||
config: wrapper.RootwrapConfig,
|
||||
filters: list[filters_mod.CommandFilter],
|
||||
) -> None:
|
||||
temp_dir = tempfile.mkdtemp(prefix='rootwrap-')
|
||||
LOG.debug("Created temporary directory %s", temp_dir)
|
||||
try:
|
||||
@@ -160,7 +187,8 @@ def daemon_start(config, filters):
|
||||
os.chmod(socket_path, rw_rw_rw_)
|
||||
sys.stdout.buffer.write(socket_path.encode('utf-8'))
|
||||
sys.stdout.buffer.write(b'\n')
|
||||
sys.stdout.buffer.write(bytes(server.authkey))
|
||||
# type stubs don't expose this paramter currently
|
||||
sys.stdout.buffer.write(bytes(server.authkey)) # type: ignore[attr-defined]
|
||||
sys.stdin.close()
|
||||
sys.stdout.close()
|
||||
sys.stderr.close()
|
||||
@@ -171,7 +199,7 @@ def daemon_start(config, filters):
|
||||
LOG.info("Starting rootwrap daemon main loop")
|
||||
server.serve_forever()
|
||||
finally:
|
||||
conn = server.listener
|
||||
conn: jsonrpc.JsonListener = server.listener # type: ignore[attr-defined]
|
||||
# This will break accept() loop with EOFError if it was not in the
|
||||
# main thread (as in Python 3.x)
|
||||
conn.close()
|
||||
@@ -194,11 +222,13 @@ def daemon_start(config, filters):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def daemon_stop(server, signal, frame):
|
||||
LOG.info("Got signal %s. Shutting down server", signal)
|
||||
def daemon_stop(
|
||||
server: managers.Server, signum: int, frame: FrameType | None
|
||||
) -> None:
|
||||
LOG.info("Got signal %s. Shutting down server", signum)
|
||||
# Signals are caught in the main thread which means this handler will run
|
||||
# in the middle of serve_forever() loop. It will catch this exception and
|
||||
# properly return. Since all threads created by server_forever are
|
||||
# daemonic, we need to join them afterwards. In Python 3 we can just hit
|
||||
# stop_event instead.
|
||||
server.stop_event.set()
|
||||
server.stop_event.set() # type: ignore[attr-defined]
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -28,12 +29,12 @@ if sys.platform != 'win32':
|
||||
import pwd
|
||||
|
||||
|
||||
def _getuid(user):
|
||||
def _getuid(user: str) -> int:
|
||||
"""Return uid for user."""
|
||||
return pwd.getpwnam(user).pw_uid
|
||||
|
||||
|
||||
def realpath(path):
|
||||
def realpath(path: str) -> str:
|
||||
"""Return the real absolute path.
|
||||
|
||||
If the execution directory does not exist, os.getcwd() raises a
|
||||
@@ -49,14 +50,14 @@ def realpath(path):
|
||||
class CommandFilter:
|
||||
"""Command filter only checking that the 1st argument matches exec_path."""
|
||||
|
||||
def __init__(self, exec_path, run_as, *args):
|
||||
def __init__(self, exec_path: str, run_as: str, *args: str) -> None:
|
||||
self.name = ''
|
||||
self.exec_path = exec_path
|
||||
self.run_as = run_as
|
||||
self.args = args
|
||||
self.real_exec = None
|
||||
self.real_exec: str | None = None
|
||||
|
||||
def get_exec(self, exec_dirs=None):
|
||||
def get_exec(self, exec_dirs: list[str] | None = None) -> str | None:
|
||||
"""Returns existing executable, or empty string if none found."""
|
||||
exec_dirs = exec_dirs or []
|
||||
if self.real_exec is not None:
|
||||
@@ -72,7 +73,7 @@ class CommandFilter:
|
||||
break
|
||||
return self.real_exec
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
"""Only check that the first argument (command) matches exec_path."""
|
||||
if userargs:
|
||||
base_path_matches = os.path.basename(self.exec_path) == userargs[0]
|
||||
@@ -80,18 +81,20 @@ class CommandFilter:
|
||||
return exact_path_matches or base_path_matches
|
||||
return False
|
||||
|
||||
def preexec(self):
|
||||
def preexec(self) -> None:
|
||||
"""Setuid in subprocess right before command is invoked."""
|
||||
if self.run_as != 'root':
|
||||
os.setuid(_getuid(self.run_as))
|
||||
|
||||
def get_command(self, userargs, exec_dirs=None):
|
||||
def get_command(
|
||||
self, userargs: list[str], exec_dirs: list[str] | None = None
|
||||
) -> list[str]:
|
||||
"""Returns command to execute."""
|
||||
exec_dirs = exec_dirs or []
|
||||
to_exec = self.get_exec(exec_dirs=exec_dirs) or self.exec_path
|
||||
return [to_exec] + userargs[1:]
|
||||
return [to_exec] + list(userargs[1:])
|
||||
|
||||
def get_environment(self, userargs):
|
||||
def get_environment(self, userargs: list[str]) -> dict[str, str] | None:
|
||||
"""Returns specific environment to set, None if none."""
|
||||
return None
|
||||
|
||||
@@ -99,7 +102,7 @@ class CommandFilter:
|
||||
class RegExpFilter(CommandFilter):
|
||||
"""Command filter doing regexp matching for every argument."""
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
# Early skip if command or number of args don't match
|
||||
if not userargs or len(self.args) != len(userargs):
|
||||
# DENY: argument numbers don't match
|
||||
@@ -131,7 +134,7 @@ class PathFilter(CommandFilter):
|
||||
|
||||
"""
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
if not userargs or len(userargs) < 2:
|
||||
return False
|
||||
|
||||
@@ -157,7 +160,9 @@ class PathFilter(CommandFilter):
|
||||
and paths_are_within_base_dirs
|
||||
)
|
||||
|
||||
def get_command(self, userargs, exec_dirs=None):
|
||||
def get_command(
|
||||
self, userargs: list[str], exec_dirs: list[str] | None = None
|
||||
) -> list[str]:
|
||||
exec_dirs = exec_dirs or []
|
||||
command, arguments = userargs[0], userargs[1:]
|
||||
|
||||
@@ -182,14 +187,14 @@ class KillFilter(CommandFilter):
|
||||
executable, so it will only work on procfs-capable systems (not OSX).
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
def __init__(self, *args: str) -> None:
|
||||
super().__init__("/bin/kill", *args)
|
||||
|
||||
def _program(self, pid):
|
||||
def _program(self, pid: int) -> str | None:
|
||||
"""Determine the program associated with pid"""
|
||||
|
||||
try:
|
||||
command = os.readlink(f"/proc/{int(pid)}/exe")
|
||||
command = os.readlink(f"/proc/{pid}/exe")
|
||||
except (ValueError, OSError):
|
||||
# Incorrect PID
|
||||
return None
|
||||
@@ -211,10 +216,10 @@ class KillFilter(CommandFilter):
|
||||
# a ';......' or '.#prelink#......' suffix etc.
|
||||
# So defer to /proc/PID/cmdline in that case.
|
||||
try:
|
||||
with open(f"/proc/{int(pid)}/cmdline") as pfile:
|
||||
with open(f"/proc/{pid}/cmdline") as pfile:
|
||||
cmdline = pfile.read().partition('\0')[0]
|
||||
|
||||
cmdline = shutil.which(cmdline)
|
||||
cmdline = shutil.which(cmdline) or cmdline
|
||||
if os.path.isfile(cmdline):
|
||||
command = cmdline
|
||||
|
||||
@@ -225,7 +230,7 @@ class KillFilter(CommandFilter):
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
if not userargs or userargs[0] != "kill":
|
||||
return False
|
||||
args = list(userargs)
|
||||
@@ -243,7 +248,11 @@ class KillFilter(CommandFilter):
|
||||
# No signal requested, but filter requires specific signal
|
||||
return False
|
||||
|
||||
command = self._program(args[1])
|
||||
try:
|
||||
pid = int(args[1])
|
||||
except ValueError:
|
||||
return False
|
||||
command = self._program(pid)
|
||||
if not command:
|
||||
return False
|
||||
|
||||
@@ -263,18 +272,20 @@ class KillFilter(CommandFilter):
|
||||
class ReadFileFilter(CommandFilter):
|
||||
"""Specific filter for the utils.read_file_as_root call."""
|
||||
|
||||
def __init__(self, file_path, *args):
|
||||
def __init__(self, file_path: str, *args: str) -> None:
|
||||
self.file_path = file_path
|
||||
super().__init__("/bin/cat", "root", *args)
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
return userargs == ['cat', self.file_path]
|
||||
|
||||
|
||||
class IpFilter(CommandFilter):
|
||||
"""Specific filter for the ip utility to that does not match exec."""
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
if not userargs:
|
||||
return False
|
||||
if userargs[0] == 'ip':
|
||||
# Avoid the 'netns exec' command here
|
||||
for a, b in zip(userargs[1:], userargs[2:]):
|
||||
@@ -282,6 +293,7 @@ class IpFilter(CommandFilter):
|
||||
return b not in EXEC_VARS
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class EnvFilter(CommandFilter):
|
||||
@@ -291,7 +303,7 @@ class EnvFilter(CommandFilter):
|
||||
leading env A=B.. strings appropriately.
|
||||
"""
|
||||
|
||||
def _extract_env(self, arglist):
|
||||
def _extract_env(self, arglist: Sequence[str]) -> set[str]:
|
||||
"""Extract all leading NAME=VALUE arguments from arglist."""
|
||||
|
||||
envs = set()
|
||||
@@ -301,7 +313,7 @@ class EnvFilter(CommandFilter):
|
||||
envs.add(arg.partition('=')[0])
|
||||
return envs
|
||||
|
||||
def __init__(self, exec_path, run_as, *args):
|
||||
def __init__(self, exec_path: str, run_as: str, *args: str) -> None:
|
||||
super().__init__(exec_path, run_as, *args)
|
||||
|
||||
env_list = self._extract_env(self.args)
|
||||
@@ -310,7 +322,9 @@ class EnvFilter(CommandFilter):
|
||||
if "env" in exec_path and len(env_list) < len(self.args):
|
||||
self.exec_path = self.args[len(env_list)]
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
if not userargs:
|
||||
return False
|
||||
# ignore leading 'env'
|
||||
if userargs[0] == 'env':
|
||||
userargs.pop(0)
|
||||
@@ -325,13 +339,13 @@ class EnvFilter(CommandFilter):
|
||||
user_command = userargs[len(user_envs) : len(user_envs) + 1]
|
||||
|
||||
# match first non-env argument with CommandFilter
|
||||
return (
|
||||
return bool(
|
||||
super().match(user_command)
|
||||
and len(filter_envs)
|
||||
and user_envs == filter_envs
|
||||
)
|
||||
|
||||
def exec_args(self, userargs):
|
||||
def exec_args(self, userargs: list[str]) -> list[str]:
|
||||
args = userargs[:]
|
||||
|
||||
# ignore leading 'env'
|
||||
@@ -344,11 +358,13 @@ class EnvFilter(CommandFilter):
|
||||
|
||||
return args
|
||||
|
||||
def get_command(self, userargs, exec_dirs=[]):
|
||||
to_exec = self.get_exec(exec_dirs=exec_dirs) or self.exec_path
|
||||
def get_command(
|
||||
self, userargs: list[str], exec_dirs: list[str] | None = None
|
||||
) -> list[str]:
|
||||
to_exec = self.get_exec(exec_dirs=exec_dirs or []) or self.exec_path
|
||||
return [to_exec] + self.exec_args(userargs)[1:]
|
||||
|
||||
def get_environment(self, userargs):
|
||||
def get_environment(self, userargs: list[str]) -> dict[str, str] | None:
|
||||
env = os.environ.copy()
|
||||
|
||||
# ignore leading 'env'
|
||||
@@ -367,17 +383,17 @@ class EnvFilter(CommandFilter):
|
||||
|
||||
|
||||
class ChainingFilter(CommandFilter):
|
||||
def exec_args(self, userargs):
|
||||
def exec_args(self, userargs: list[str]) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
class IpNetnsExecFilter(ChainingFilter):
|
||||
"""Specific filter for the ip utility to that does match exec."""
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
# Network namespaces currently require root
|
||||
# require <ns> argument
|
||||
if self.run_as != "root" or len(userargs) < 4:
|
||||
if not userargs or self.run_as != "root" or len(userargs) < 4:
|
||||
return False
|
||||
|
||||
return (
|
||||
@@ -386,7 +402,7 @@ class IpNetnsExecFilter(ChainingFilter):
|
||||
and userargs[2] in EXEC_VARS
|
||||
)
|
||||
|
||||
def exec_args(self, userargs):
|
||||
def exec_args(self, userargs: list[str]) -> list[str]:
|
||||
args = userargs[4:]
|
||||
if args:
|
||||
args[0] = os.path.basename(args[0])
|
||||
@@ -400,7 +416,7 @@ class ChainingRegExpFilter(ChainingFilter):
|
||||
specified as the arguments must be also allowed to execute directly.
|
||||
"""
|
||||
|
||||
def match(self, userargs):
|
||||
def match(self, userargs: list[str] | None) -> bool:
|
||||
# Early skip if number of args is smaller than the filter
|
||||
if not userargs or len(self.args) > len(userargs):
|
||||
return False
|
||||
@@ -416,7 +432,7 @@ class ChainingRegExpFilter(ChainingFilter):
|
||||
# ALLOW: All arguments matched
|
||||
return True
|
||||
|
||||
def exec_args(self, userargs):
|
||||
def exec_args(self, userargs: list[str]) -> list[str]:
|
||||
args = userargs[len(self.args) :]
|
||||
if args:
|
||||
args[0] = os.path.basename(args[0])
|
||||
|
||||
@@ -20,13 +20,14 @@ from multiprocessing import connection
|
||||
from multiprocessing import managers
|
||||
import socket
|
||||
import struct
|
||||
from typing import Any
|
||||
import weakref
|
||||
|
||||
from oslo_rootwrap import wrapper
|
||||
|
||||
|
||||
class RpcJSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
def default(self, o: object) -> Any:
|
||||
# We need to pass bytes unchanged as they are expected in arguments for
|
||||
# and are result of Popen.communicate()
|
||||
if isinstance(o, bytes):
|
||||
@@ -46,7 +47,7 @@ class RpcJSONEncoder(json.JSONEncoder):
|
||||
|
||||
|
||||
# Parse whatever RpcJSONEncoder supplied us with
|
||||
def rpc_object_hook(obj):
|
||||
def rpc_object_hook(obj: dict[str, Any]) -> Any:
|
||||
if "__exception__" in obj:
|
||||
type_name = obj.pop("__exception__")
|
||||
if type_name not in ("NoFilterMatched", "FilterMatchNotExecutable"):
|
||||
@@ -60,7 +61,12 @@ def rpc_object_hook(obj):
|
||||
|
||||
|
||||
class JsonListener:
|
||||
def __init__(self, address, backlog=1):
|
||||
address: str
|
||||
closed: bool
|
||||
_socket: socket.socket
|
||||
_accepted: weakref.WeakSet["JsonConnection"]
|
||||
|
||||
def __init__(self, address: str, backlog: int = 1) -> None:
|
||||
self.address = address
|
||||
self._socket = socket.socket(socket.AF_UNIX)
|
||||
try:
|
||||
@@ -73,7 +79,7 @@ class JsonListener:
|
||||
self.closed = False
|
||||
self._accepted = weakref.WeakSet()
|
||||
|
||||
def accept(self):
|
||||
def accept(self) -> "JsonConnection":
|
||||
while True:
|
||||
try:
|
||||
s, _ = self._socket.accept()
|
||||
@@ -89,65 +95,67 @@ class JsonListener:
|
||||
self._accepted.add(conn)
|
||||
return conn
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
if not self.closed:
|
||||
self._socket.shutdown(socket.SHUT_RDWR)
|
||||
self._socket.close()
|
||||
self.closed = True
|
||||
|
||||
def get_accepted(self):
|
||||
def get_accepted(self) -> weakref.WeakSet["JsonConnection"]:
|
||||
return self._accepted
|
||||
|
||||
|
||||
if hasattr(managers.Server, 'accepter'):
|
||||
# In Python 3 accepter() thread has infinite loop. We break it with
|
||||
# EOFError, so we should silence this error here.
|
||||
def silent_accepter(self):
|
||||
def silent_accepter(self: managers.Server) -> None:
|
||||
try:
|
||||
old_accepter(self)
|
||||
except EOFError:
|
||||
pass
|
||||
|
||||
old_accepter = managers.Server.accepter
|
||||
managers.Server.accepter = silent_accepter
|
||||
managers.Server.accepter = silent_accepter # type: ignore[method-assign]
|
||||
|
||||
|
||||
class JsonConnection:
|
||||
def __init__(self, sock):
|
||||
_socket: socket.socket
|
||||
|
||||
def __init__(self, sock: socket.socket) -> None:
|
||||
sock.setblocking(True)
|
||||
self._socket = sock
|
||||
|
||||
def send_bytes(self, s):
|
||||
def send_bytes(self, s: bytes) -> None:
|
||||
self._socket.sendall(struct.pack('!Q', len(s)))
|
||||
self._socket.sendall(s)
|
||||
|
||||
def recv_bytes(self, maxsize=None):
|
||||
def recv_bytes(self, maxsize: int | None = None) -> bytes:
|
||||
item = struct.unpack('!Q', self.recvall(8))[0]
|
||||
if maxsize is not None and item > maxsize:
|
||||
raise RuntimeError("Too big message received")
|
||||
s = self.recvall(item)
|
||||
return s
|
||||
|
||||
def send(self, obj):
|
||||
def send(self, obj: object) -> None:
|
||||
s = self.dumps(obj)
|
||||
self.send_bytes(s)
|
||||
|
||||
def recv(self):
|
||||
def recv(self) -> Any:
|
||||
s = self.recv_bytes()
|
||||
return self.loads(s)
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self._socket.close()
|
||||
|
||||
def half_close(self):
|
||||
def half_close(self) -> None:
|
||||
self._socket.shutdown(socket.SHUT_RD)
|
||||
|
||||
# We have to use slow version of recvall with eventlet because of a bug in
|
||||
# GreenSocket.recv_into:
|
||||
# https://bitbucket.org/eventlet/eventlet/pull-request/41
|
||||
def _recvall_slow(self, size):
|
||||
def _recvall_slow(self, size: int) -> bytes:
|
||||
remaining = size
|
||||
res = []
|
||||
res: list[bytes] = []
|
||||
while remaining:
|
||||
piece = self._socket.recv(remaining)
|
||||
if not piece:
|
||||
@@ -156,7 +164,7 @@ class JsonConnection:
|
||||
remaining -= len(piece)
|
||||
return b''.join(res)
|
||||
|
||||
def recvall(self, size):
|
||||
def recvall(self, size: int) -> bytes:
|
||||
buf = bytearray(size)
|
||||
mem = memoryview(buf)
|
||||
got = 0
|
||||
@@ -170,11 +178,11 @@ class JsonConnection:
|
||||
return bytes(buf)
|
||||
|
||||
@staticmethod
|
||||
def dumps(obj):
|
||||
def dumps(obj: object) -> bytes:
|
||||
return json.dumps(obj, cls=RpcJSONEncoder).encode('utf-8')
|
||||
|
||||
@staticmethod
|
||||
def loads(s):
|
||||
def loads(s: bytes) -> Any:
|
||||
res = json.loads(s.decode('utf-8'), object_hook=rpc_object_hook)
|
||||
try:
|
||||
kind = res[0]
|
||||
@@ -190,11 +198,11 @@ class JsonConnection:
|
||||
|
||||
|
||||
class JsonClient(JsonConnection):
|
||||
def __init__(self, address, authkey=None):
|
||||
def __init__(self, address: str, authkey: bytes | None = None) -> None:
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
sock.setblocking(True)
|
||||
sock.connect(address)
|
||||
super().__init__(sock)
|
||||
if authkey is not None:
|
||||
connection.answer_challenge(self, authkey)
|
||||
connection.deliver_challenge(self, authkey)
|
||||
connection.answer_challenge(self, authkey) # type: ignore[arg-type]
|
||||
connection.deliver_challenge(self, authkey) # type: ignore[arg-type]
|
||||
|
||||
0
oslo_rootwrap/py.typed
Normal file
0
oslo_rootwrap/py.typed
Normal file
108
oslo_rootwrap/tests/functional_base.py
Normal file
108
oslo_rootwrap/tests/functional_base.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2014 Mirantis Inc.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import pwd
|
||||
|
||||
try:
|
||||
import eventlet
|
||||
except ImportError:
|
||||
eventlet = None
|
||||
|
||||
import fixtures
|
||||
import testtools
|
||||
|
||||
from oslo_rootwrap import cmd
|
||||
|
||||
|
||||
class _FunctionalBase(testtools.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
tmpdir = self.useFixture(fixtures.TempDir()).path
|
||||
self.config_file = os.path.join(tmpdir, 'rootwrap.conf')
|
||||
self.later_cmd = os.path.join(tmpdir, 'later_install_cmd')
|
||||
filters_dir = os.path.join(tmpdir, 'filters.d')
|
||||
filters_file = os.path.join(tmpdir, 'filters.d', 'test.filters')
|
||||
os.mkdir(filters_dir)
|
||||
with open(self.config_file, 'w') as f:
|
||||
f.write(f"""[DEFAULT]
|
||||
filters_path={filters_dir}
|
||||
daemon_timeout=10
|
||||
exec_dirs=/bin""")
|
||||
with open(filters_file, 'w') as f:
|
||||
f.write(f"""[Filters]
|
||||
echo: CommandFilter, /bin/echo, root
|
||||
cat: CommandFilter, /bin/cat, root
|
||||
sh: CommandFilter, /bin/sh, root
|
||||
id: CommandFilter, /usr/bin/id, nobody
|
||||
unknown_cmd: CommandFilter, /unknown/unknown_cmd, root
|
||||
later_install_cmd: CommandFilter, {self.later_cmd}, root
|
||||
""")
|
||||
|
||||
def _test_run_once(self, expect_byte: bool = True) -> None:
|
||||
code, out, err = self.execute(['echo', 'teststr'])
|
||||
self.assertEqual(0, code)
|
||||
expect_out: str | bytes
|
||||
expect_err: str | bytes
|
||||
if expect_byte:
|
||||
expect_out = b'teststr\n'
|
||||
expect_err = b''
|
||||
else:
|
||||
expect_out = 'teststr\n'
|
||||
expect_err = ''
|
||||
self.assertEqual(expect_out, out)
|
||||
self.assertEqual(expect_err, err)
|
||||
|
||||
def _test_run_with_stdin(self, expect_byte: bool = True) -> None:
|
||||
code, out, err = self.execute(['cat'], stdin=b'teststr')
|
||||
self.assertEqual(0, code)
|
||||
expect_out: str | bytes
|
||||
expect_err: str | bytes
|
||||
if expect_byte:
|
||||
expect_out = b'teststr'
|
||||
expect_err = b''
|
||||
else:
|
||||
expect_out = 'teststr'
|
||||
expect_err = ''
|
||||
self.assertEqual(expect_out, out)
|
||||
self.assertEqual(expect_err, err)
|
||||
|
||||
def test_run_with_path(self):
|
||||
code, out, err = self.execute(['/bin/echo', 'teststr'])
|
||||
self.assertEqual(0, code)
|
||||
|
||||
def test_run_with_bogus_path(self):
|
||||
code, out, err = self.execute(['/home/bob/bin/echo', 'teststr'])
|
||||
self.assertEqual(cmd.RC_UNAUTHORIZED, code)
|
||||
|
||||
def test_run_command_not_found(self):
|
||||
code, out, err = self.execute(['unknown_cmd'])
|
||||
self.assertEqual(cmd.RC_NOEXECFOUND, code)
|
||||
|
||||
def test_run_unauthorized_command(self):
|
||||
code, out, err = self.execute(['unauthorized_cmd'])
|
||||
self.assertEqual(cmd.RC_UNAUTHORIZED, code)
|
||||
|
||||
def test_run_as(self):
|
||||
if os.getuid() != 0:
|
||||
self.skip('Test requires root (for setuid)')
|
||||
|
||||
# Should run as 'nobody'
|
||||
code, out, err = self.execute(['id', '-u'])
|
||||
self.assertEqual(pwd.getpwnam('nobody').pw_uid, int(out.strip()))
|
||||
|
||||
# Should run as 'root'
|
||||
code, out, err = self.execute(['sh', '-c', 'id -u'])
|
||||
self.assertEqual(0, int(out.strip()))
|
||||
@@ -17,7 +17,6 @@ import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import pwd
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
@@ -31,94 +30,16 @@ except ImportError:
|
||||
eventlet = None
|
||||
|
||||
import fixtures
|
||||
import testtools
|
||||
from testtools import content
|
||||
|
||||
|
||||
from oslo_rootwrap import client
|
||||
from oslo_rootwrap import cmd
|
||||
from oslo_rootwrap import subprocess
|
||||
from oslo_rootwrap.tests import functional_base
|
||||
from oslo_rootwrap.tests import run_daemon
|
||||
|
||||
|
||||
class _FunctionalBase:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
tmpdir = self.useFixture(fixtures.TempDir()).path
|
||||
self.config_file = os.path.join(tmpdir, 'rootwrap.conf')
|
||||
self.later_cmd = os.path.join(tmpdir, 'later_install_cmd')
|
||||
filters_dir = os.path.join(tmpdir, 'filters.d')
|
||||
filters_file = os.path.join(tmpdir, 'filters.d', 'test.filters')
|
||||
os.mkdir(filters_dir)
|
||||
with open(self.config_file, 'w') as f:
|
||||
f.write(f"""[DEFAULT]
|
||||
filters_path={filters_dir}
|
||||
daemon_timeout=10
|
||||
exec_dirs=/bin""")
|
||||
with open(filters_file, 'w') as f:
|
||||
f.write(f"""[Filters]
|
||||
echo: CommandFilter, /bin/echo, root
|
||||
cat: CommandFilter, /bin/cat, root
|
||||
sh: CommandFilter, /bin/sh, root
|
||||
id: CommandFilter, /usr/bin/id, nobody
|
||||
unknown_cmd: CommandFilter, /unknown/unknown_cmd, root
|
||||
later_install_cmd: CommandFilter, {self.later_cmd}, root
|
||||
""")
|
||||
|
||||
def _test_run_once(self, expect_byte=True):
|
||||
code, out, err = self.execute(['echo', 'teststr'])
|
||||
self.assertEqual(0, code)
|
||||
if expect_byte:
|
||||
expect_out = b'teststr\n'
|
||||
expect_err = b''
|
||||
else:
|
||||
expect_out = 'teststr\n'
|
||||
expect_err = ''
|
||||
self.assertEqual(expect_out, out)
|
||||
self.assertEqual(expect_err, err)
|
||||
|
||||
def _test_run_with_stdin(self, expect_byte=True):
|
||||
code, out, err = self.execute(['cat'], stdin=b'teststr')
|
||||
self.assertEqual(0, code)
|
||||
if expect_byte:
|
||||
expect_out = b'teststr'
|
||||
expect_err = b''
|
||||
else:
|
||||
expect_out = 'teststr'
|
||||
expect_err = ''
|
||||
self.assertEqual(expect_out, out)
|
||||
self.assertEqual(expect_err, err)
|
||||
|
||||
def test_run_with_path(self):
|
||||
code, out, err = self.execute(['/bin/echo', 'teststr'])
|
||||
self.assertEqual(0, code)
|
||||
|
||||
def test_run_with_bogus_path(self):
|
||||
code, out, err = self.execute(['/home/bob/bin/echo', 'teststr'])
|
||||
self.assertEqual(cmd.RC_UNAUTHORIZED, code)
|
||||
|
||||
def test_run_command_not_found(self):
|
||||
code, out, err = self.execute(['unknown_cmd'])
|
||||
self.assertEqual(cmd.RC_NOEXECFOUND, code)
|
||||
|
||||
def test_run_unauthorized_command(self):
|
||||
code, out, err = self.execute(['unauthorized_cmd'])
|
||||
self.assertEqual(cmd.RC_UNAUTHORIZED, code)
|
||||
|
||||
def test_run_as(self):
|
||||
if os.getuid() != 0:
|
||||
self.skip('Test requires root (for setuid)')
|
||||
|
||||
# Should run as 'nobody'
|
||||
code, out, err = self.execute(['id', '-u'])
|
||||
self.assertEqual(pwd.getpwnam('nobody').pw_uid, int(out.strip()))
|
||||
|
||||
# Should run as 'root'
|
||||
code, out, err = self.execute(['sh', '-c', 'id -u'])
|
||||
self.assertEqual(0, int(out.strip()))
|
||||
|
||||
|
||||
class RootwrapTest(_FunctionalBase, testtools.TestCase):
|
||||
class RootwrapTest(functional_base._FunctionalBase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cmd = [
|
||||
@@ -151,7 +72,7 @@ class RootwrapTest(_FunctionalBase, testtools.TestCase):
|
||||
self._test_run_with_stdin(expect_byte=True)
|
||||
|
||||
|
||||
class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
class RootwrapDaemonTest(functional_base._FunctionalBase):
|
||||
def assert_unpatched(self):
|
||||
# We need to verify that these tests are run without eventlet patching
|
||||
if eventlet and eventlet.patcher.is_monkey_patched('socket'):
|
||||
@@ -210,6 +131,7 @@ class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
@self.addCleanup
|
||||
def finalize_client():
|
||||
if self.client._initialized:
|
||||
assert self.client._finalize is not None # narrow type
|
||||
self.client._finalize()
|
||||
|
||||
self.execute = self.client.execute
|
||||
@@ -233,6 +155,7 @@ class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
# Let the client start a daemon
|
||||
self.execute(['cat'])
|
||||
# Make daemon go away
|
||||
assert self.client._process is not None
|
||||
os.kill(self.client._process.pid, signal.SIGTERM)
|
||||
# Expect client to successfully restart daemon and run simple request
|
||||
self.test_run_once()
|
||||
@@ -246,11 +169,13 @@ class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
self.execute(['echo'])
|
||||
restart.assert_called_once()
|
||||
|
||||
def _exec_thread(self, fifo_path):
|
||||
def _exec_thread(self, fifo_path: str) -> None:
|
||||
try:
|
||||
# Run a shell script that signals calling process through FIFO and
|
||||
# then hangs around for 1 sec
|
||||
self._thread_res = self.execute(
|
||||
self._thread_res: (
|
||||
tuple[int, str | bytes, str | bytes] | Exception
|
||||
) = self.execute(
|
||||
['sh', '-c', f'echo > "{fifo_path}"; sleep 1; echo OK']
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -270,6 +195,7 @@ class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
with open(fifo_path) as f:
|
||||
f.readline()
|
||||
# Gracefully kill daemon process
|
||||
assert self.client._process is not None
|
||||
os.kill(self.client._process.pid, signal.SIGTERM)
|
||||
# Expect daemon to wait for our request to finish
|
||||
t.join()
|
||||
@@ -284,10 +210,13 @@ class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
def _test_daemon_cleanup(self):
|
||||
# Start a daemon
|
||||
self.execute(['cat'])
|
||||
assert self.client._manager is not None
|
||||
socket_path = self.client._manager.address
|
||||
assert isinstance(socket_path, str)
|
||||
# Stop it one way or another
|
||||
yield
|
||||
process = self.client._process
|
||||
assert process is not None
|
||||
stop = threading.Event()
|
||||
|
||||
# Start background thread that would kill process in 1 second if it
|
||||
@@ -299,7 +228,7 @@ class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
|
||||
threading.Thread(target=sleep_kill).start()
|
||||
# Wait for process to finish one way or another
|
||||
self.client._process.wait()
|
||||
process.wait()
|
||||
# Notify background thread that process is dead (no need to kill it)
|
||||
stop.set()
|
||||
# Fail if the process got killed by the background thread
|
||||
@@ -318,9 +247,11 @@ class RootwrapDaemonTest(_FunctionalBase, testtools.TestCase):
|
||||
# Run _test_daemon_cleanup stopping daemon as Client instance would
|
||||
# normally do
|
||||
with self._test_daemon_cleanup():
|
||||
assert self.client._finalize is not None # narrow type
|
||||
self.client._finalize()
|
||||
|
||||
def test_daemon_cleanup_signal(self):
|
||||
# Run _test_daemon_cleanup stopping daemon with SIGTERM signal
|
||||
with self._test_daemon_cleanup():
|
||||
assert self.client._process is not None
|
||||
os.kill(self.client._process.pid, signal.SIGTERM)
|
||||
|
||||
@@ -153,6 +153,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
self.assertTrue(f.match(usercmd))
|
||||
self.assertEqual(['/usr/bin/dnsmasq', 'foo'], f.get_command(usercmd))
|
||||
env = f.get_environment(usercmd)
|
||||
assert env is not None
|
||||
self.assertEqual('A', env.get(config_file_arg))
|
||||
self.assertEqual('foobar', env.get('NETWORK_ID'))
|
||||
|
||||
@@ -176,7 +177,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
self.assertTrue(f.match(usercmd))
|
||||
|
||||
# require given environment variables to match
|
||||
self.assertFalse(f.match([envcmd, 'C=ELSE']))
|
||||
self.assertFalse(f.match(envcmd + ['C=ELSE']))
|
||||
self.assertFalse(f.match(['env', 'C=xx']))
|
||||
self.assertFalse(f.match(['env', 'A=xx']))
|
||||
|
||||
@@ -190,6 +191,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
# ensure that the env command is stripped when executing
|
||||
self.assertEqual(realcmd, f.exec_args(usercmd))
|
||||
env = f.get_environment(usercmd)
|
||||
assert env is not None
|
||||
# check that environment variables are set
|
||||
self.assertEqual('/some/thing', env.get('A'))
|
||||
self.assertEqual('somethingelse', env.get('B'))
|
||||
@@ -209,6 +211,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
self.assertEqual(realcmd, f.get_command(envset + realcmd))
|
||||
|
||||
env = f.get_environment(envset + realcmd)
|
||||
assert env is not None
|
||||
# check that environment variables are set
|
||||
self.assertEqual('/some/thing', env.get('A'))
|
||||
self.assertEqual('somethingelse', env.get('B'))
|
||||
@@ -227,14 +230,14 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
f = filters.KillFilter("root", "/bin/cat", "-9", "-HUP")
|
||||
f2 = filters.KillFilter("root", "/usr/bin/cat", "-9", "-HUP")
|
||||
f3 = filters.KillFilter("root", "/usr/bin/coreutils", "-9", "-HUP")
|
||||
usercmd = ['kill', '-ALRM', p.pid]
|
||||
usercmd = ['kill', '-ALRM', str(p.pid)]
|
||||
# Incorrect signal should fail
|
||||
self.assertFalse(f.match(usercmd) or f2.match(usercmd))
|
||||
usercmd = ['kill', p.pid]
|
||||
usercmd = ['kill', str(p.pid)]
|
||||
# Providing no signal should fail
|
||||
self.assertFalse(f.match(usercmd) or f2.match(usercmd))
|
||||
# Providing matching signal should be allowed
|
||||
usercmd = ['kill', '-9', p.pid]
|
||||
usercmd = ['kill', '-9', str(p.pid)]
|
||||
self.assertTrue(
|
||||
f.match(usercmd) or f2.match(usercmd) or f3.match(usercmd)
|
||||
)
|
||||
@@ -242,13 +245,13 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
f = filters.KillFilter("root", "/bin/cat")
|
||||
f2 = filters.KillFilter("root", "/usr/bin/cat")
|
||||
f3 = filters.KillFilter("root", "/usr/bin/coreutils")
|
||||
usercmd = ['kill', os.getpid()]
|
||||
usercmd = ['kill', str(os.getpid())]
|
||||
# Our own PID does not match /bin/sleep, so it should fail
|
||||
self.assertFalse(f.match(usercmd) or f2.match(usercmd))
|
||||
usercmd = ['kill', 999999]
|
||||
usercmd = ['kill', '999999']
|
||||
# Nonexistent PID should fail
|
||||
self.assertFalse(f.match(usercmd) or f2.match(usercmd))
|
||||
usercmd = ['kill', p.pid]
|
||||
usercmd = ['kill', str(p.pid)]
|
||||
# Providing no signal should work
|
||||
self.assertTrue(
|
||||
f.match(usercmd) or f2.match(usercmd) or f3.match(usercmd)
|
||||
@@ -258,10 +261,10 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
f = filters.KillFilter("root", "cat")
|
||||
f2 = filters.KillFilter("root", "coreutils")
|
||||
# Our own PID does not match so it should fail
|
||||
usercmd = ['kill', os.getpid()]
|
||||
usercmd = ['kill', str(os.getpid())]
|
||||
self.assertFalse(f.match(usercmd))
|
||||
# Filter should find cat in /bin or /usr/bin
|
||||
usercmd = ['kill', p.pid]
|
||||
usercmd = ['kill', str(p.pid)]
|
||||
self.assertTrue(f.match(usercmd) or f2.match(usercmd))
|
||||
# Filter shouldn't be able to find binary in $PATH, so fail
|
||||
with fixtures.EnvironmentVariable("PATH", "/foo:/bar"):
|
||||
@@ -278,7 +281,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
"""Makes sure ValueError from bug 926412 is gone."""
|
||||
f = filters.KillFilter("root", "")
|
||||
# Providing anything other than kill should be False
|
||||
usercmd = ['notkill', 999999]
|
||||
usercmd = ['notkill', '999999']
|
||||
self.assertFalse(f.match(usercmd))
|
||||
# Providing something that is not a pid should be False
|
||||
usercmd = ['kill', 'notapid']
|
||||
@@ -291,7 +294,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
"""Makes sure deleted exe's are killed correctly."""
|
||||
command = "/bin/commandddddd"
|
||||
f = filters.KillFilter("root", command)
|
||||
usercmd = ['kill', 1234]
|
||||
usercmd = ['kill', '1234']
|
||||
# Providing no signal should work
|
||||
with mock.patch('os.readlink') as readlink:
|
||||
readlink.return_value = command + ' (deleted)'
|
||||
@@ -309,7 +312,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
"""Makes sure upgraded exe's are killed correctly."""
|
||||
f = filters.KillFilter("root", "/bin/commandddddd")
|
||||
command = "/bin/commandddddd"
|
||||
usercmd = ['kill', 1234]
|
||||
usercmd = ['kill', '1234']
|
||||
|
||||
def fake_exists(path):
|
||||
return path == command
|
||||
@@ -328,7 +331,7 @@ class RootwrapTestCase(testtools.TestCase):
|
||||
"""Makes sure renamed exe's are killed correctly."""
|
||||
command = "/bin/commandddddd"
|
||||
f = filters.KillFilter("root", command)
|
||||
usercmd = ['kill', 1234]
|
||||
usercmd = ['kill', '1234']
|
||||
|
||||
def fake_os_func(path, *args):
|
||||
return path == command
|
||||
@@ -739,11 +742,17 @@ class PathFilterTestCase(testtools.TestCase):
|
||||
|
||||
|
||||
class RunOneCommandTestCase(testtools.TestCase):
|
||||
def _test_returncode_helper(self, returncode, expected):
|
||||
def _test_returncode_helper(self, returncode: int, expected: int) -> None:
|
||||
with mock.patch.object(wrapper, 'start_subprocess') as mock_start:
|
||||
with mock.patch('sys.exit') as mock_exit:
|
||||
mock_start.return_value.wait.return_value = returncode
|
||||
cmd.run_one_command(None, mock.Mock(), None, None)
|
||||
# Using mocks for testing, so ignore arg types
|
||||
cmd.run_one_command(
|
||||
None, # type: ignore[arg-type]
|
||||
mock.Mock(),
|
||||
None, # type: ignore[arg-type]
|
||||
None, # type: ignore[arg-type]
|
||||
)
|
||||
mock_exit.assert_called_once_with(expected)
|
||||
|
||||
def test_positive_returncode(self):
|
||||
|
||||
@@ -19,6 +19,7 @@ import logging.handlers
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
|
||||
from oslo_rootwrap import filters
|
||||
from oslo_rootwrap import subprocess
|
||||
@@ -38,12 +39,22 @@ class NoFilterMatched(Exception):
|
||||
class FilterMatchNotExecutable(Exception):
|
||||
"""Raised when a filter matched but no executable was found."""
|
||||
|
||||
def __init__(self, match=None, **kwargs):
|
||||
def __init__(
|
||||
self, match: filters.CommandFilter | None = None, **kwargs: Any
|
||||
) -> None:
|
||||
self.match = match
|
||||
|
||||
|
||||
class RootwrapConfig:
|
||||
def __init__(self, config):
|
||||
filters_path: list[str]
|
||||
exec_dirs: list[str]
|
||||
syslog_log_facility: int
|
||||
syslog_log_level: int
|
||||
use_syslog: bool
|
||||
daemon_timeout: int
|
||||
rlimit_nofile: int
|
||||
|
||||
def __init__(self, config: configparser.RawConfigParser) -> None:
|
||||
# filters_path
|
||||
self.filters_path = config.get("DEFAULT", "filters_path").split(",")
|
||||
|
||||
@@ -60,24 +71,27 @@ class RootwrapConfig:
|
||||
if config.has_option("DEFAULT", "syslog_log_facility"):
|
||||
v = config.get("DEFAULT", "syslog_log_facility")
|
||||
facility_names = logging.handlers.SysLogHandler.facility_names
|
||||
self.syslog_log_facility = getattr(
|
||||
facility: int | None = getattr(
|
||||
logging.handlers.SysLogHandler, v, None
|
||||
)
|
||||
if self.syslog_log_facility is None and v in facility_names:
|
||||
self.syslog_log_facility = facility_names.get(v)
|
||||
if self.syslog_log_facility is None:
|
||||
if facility is None and v in facility_names:
|
||||
facility = facility_names.get(v)
|
||||
if facility is None:
|
||||
raise ValueError(f'Unexpected syslog_log_facility: {v}')
|
||||
self.syslog_log_facility = facility
|
||||
else:
|
||||
default_facility = logging.handlers.SysLogHandler.LOG_SYSLOG
|
||||
self.syslog_log_facility = default_facility
|
||||
self.syslog_log_facility = (
|
||||
logging.handlers.SysLogHandler.LOG_SYSLOG
|
||||
)
|
||||
|
||||
# syslog_log_level
|
||||
if config.has_option("DEFAULT", "syslog_log_level"):
|
||||
v = config.get("DEFAULT", "syslog_log_level")
|
||||
level = v.upper()
|
||||
self.syslog_log_level = logging.getLevelName(level)
|
||||
if self.syslog_log_level == f"Level {level}":
|
||||
level_name = v.upper()
|
||||
level: int | str = logging.getLevelName(level_name)
|
||||
if isinstance(level, str):
|
||||
raise ValueError(f'Unexpected syslog_log_level: {v!r}')
|
||||
self.syslog_log_level = level
|
||||
else:
|
||||
self.syslog_log_level = logging.ERROR
|
||||
|
||||
@@ -100,7 +114,7 @@ class RootwrapConfig:
|
||||
self.rlimit_nofile = 1024
|
||||
|
||||
|
||||
def setup_syslog(execname, facility, level):
|
||||
def setup_syslog(execname: str, facility: int, level: int) -> None:
|
||||
try:
|
||||
handler = logging.handlers.SysLogHandler(
|
||||
address='/dev/log', facility=facility
|
||||
@@ -121,7 +135,7 @@ def setup_syslog(execname, facility, level):
|
||||
rootwrap_logger.addHandler(handler)
|
||||
|
||||
|
||||
def build_filter(class_name, *args):
|
||||
def build_filter(class_name: str, *args: Any) -> filters.CommandFilter | None:
|
||||
"""Returns a filter object of class class_name."""
|
||||
if not hasattr(filters, class_name):
|
||||
LOG.warning(
|
||||
@@ -131,10 +145,10 @@ def build_filter(class_name, *args):
|
||||
)
|
||||
return None
|
||||
filterclass = getattr(filters, class_name)
|
||||
return filterclass(*args)
|
||||
return cast(filters.CommandFilter, filterclass(*args))
|
||||
|
||||
|
||||
def load_filters(filters_path):
|
||||
def load_filters(filters_path: list[str]) -> list[filters.CommandFilter]:
|
||||
"""Load filters from a list of directories."""
|
||||
filterlist = []
|
||||
for filterdir in filters_path:
|
||||
@@ -146,8 +160,7 @@ def load_filters(filters_path):
|
||||
filterfilepath = os.path.join(filterdir, filterfile)
|
||||
if not os.path.isfile(filterfilepath):
|
||||
continue
|
||||
kwargs = {"strict": False}
|
||||
filterconfig = configparser.RawConfigParser(**kwargs)
|
||||
filterconfig = configparser.RawConfigParser(strict=False)
|
||||
filterconfig.read(filterfilepath)
|
||||
for name, value in filterconfig.items("Filters"):
|
||||
filterdefinition = [s.strip() for s in value.split(',')]
|
||||
@@ -158,12 +171,17 @@ def load_filters(filters_path):
|
||||
filterlist.append(newfilter)
|
||||
# And always include privsep-helper
|
||||
privsep = build_filter("CommandFilter", "privsep-helper", "root")
|
||||
assert privsep is not None # narrow type
|
||||
privsep.name = "privsep-helper"
|
||||
filterlist.append(privsep)
|
||||
return filterlist
|
||||
|
||||
|
||||
def match_filter(filter_list, userargs, exec_dirs=None):
|
||||
def match_filter(
|
||||
filter_list: list[filters.CommandFilter],
|
||||
userargs: list[str],
|
||||
exec_dirs: list[str] | None = None,
|
||||
) -> filters.CommandFilter:
|
||||
"""Checks user command and arguments through command filters.
|
||||
|
||||
Returns the first matching filter.
|
||||
@@ -180,7 +198,7 @@ def match_filter(filter_list, userargs, exec_dirs=None):
|
||||
if isinstance(f, filters.ChainingFilter):
|
||||
# This command calls exec verify that remaining args
|
||||
# matches another filter.
|
||||
def non_chain_filter(fltr):
|
||||
def non_chain_filter(fltr: filters.CommandFilter) -> bool:
|
||||
return fltr.run_as == f.run_as and not isinstance(
|
||||
fltr, filters.ChainingFilter
|
||||
)
|
||||
@@ -212,7 +230,7 @@ def match_filter(filter_list, userargs, exec_dirs=None):
|
||||
raise NoFilterMatched()
|
||||
|
||||
|
||||
def _getlogin():
|
||||
def _getlogin() -> str | None:
|
||||
try:
|
||||
return os.getlogin()
|
||||
except OSError:
|
||||
@@ -221,10 +239,16 @@ def _getlogin():
|
||||
)
|
||||
|
||||
|
||||
def start_subprocess(filter_list, userargs, exec_dirs=[], log=False, **kwargs):
|
||||
filtermatch = match_filter(filter_list, userargs, exec_dirs)
|
||||
def start_subprocess(
|
||||
filter_list: list[filters.CommandFilter],
|
||||
userargs: list[str],
|
||||
exec_dirs: list[str] | None = None,
|
||||
log: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> subprocess.Popen[bytes]:
|
||||
filtermatch = match_filter(filter_list, userargs, exec_dirs or [])
|
||||
|
||||
command = filtermatch.get_command(userargs, exec_dirs)
|
||||
command = filtermatch.get_command(userargs, exec_dirs or [])
|
||||
if log:
|
||||
LOG.info(
|
||||
"(%s > %s) Executing %s (filter match = %s)",
|
||||
@@ -234,7 +258,7 @@ def start_subprocess(filter_list, userargs, exec_dirs=[], log=False, **kwargs):
|
||||
filtermatch.name,
|
||||
)
|
||||
|
||||
def preexec():
|
||||
def preexec() -> None:
|
||||
# Python installs a SIGPIPE handler by default. This is
|
||||
# usually not what non-Python subprocesses expect.
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_DFL)
|
||||
|
||||
@@ -25,6 +25,7 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -47,6 +48,21 @@ docstring-code-format = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E5", "E7", "E9", "F", "G", "LOG", "S", "UP"]
|
||||
ignore = ["S101"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"benchmark/benchmark.py" = ["S"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
show_column_numbers = true
|
||||
show_error_context = true
|
||||
strict = true
|
||||
disable_error_code = ["import-untyped"]
|
||||
exclude = "(?x)(doc | releasenotes)"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["oslo_rootwrap.tests.*"]
|
||||
disallow_untyped_calls = false
|
||||
disallow_untyped_defs = false
|
||||
disallow_subclassing_any = false
|
||||
|
||||
@@ -3,3 +3,4 @@
|
||||
# you find any incorrect lower bounds, let us know or propose a fix.
|
||||
|
||||
debtcollector>=3.0.0 # Apache-2.0
|
||||
pbr>=6.1.1 # Apache-2.0
|
||||
|
||||
21
tox.ini
21
tox.ini
@@ -53,12 +53,29 @@ commands =
|
||||
sphinx-build -a -E -W -d releasenotes/build/doctrees --keep-going -b html releasenotes/source releasenotes/build/html
|
||||
|
||||
[testenv:pep8]
|
||||
skip_install = true
|
||||
description =
|
||||
Run style checks.
|
||||
deps =
|
||||
pre-commit>=2.6.0 # MIT
|
||||
pre-commit
|
||||
{[testenv:mypy]deps}
|
||||
commands =
|
||||
pre-commit run -a
|
||||
{[testenv:mypy]commands}
|
||||
|
||||
[testenv:mypy]
|
||||
description =
|
||||
Run type checks.
|
||||
deps =
|
||||
{[testenv]deps}
|
||||
mypy
|
||||
commands =
|
||||
mypy --cache-dir="{envdir}/mypy_cache" {posargs:oslo_rootwrap}
|
||||
|
||||
[flake8]
|
||||
# We only enable the hacking (H) checks
|
||||
select = H
|
||||
|
||||
[hacking]
|
||||
import_exceptions =
|
||||
collections.abc
|
||||
typing
|
||||
|
||||
Reference in New Issue
Block a user