Add timeout to PrivContext and entrypoint_with_timeout decorator

The entrypoint_with_timeout decorator takes a timeout parameter; if the
timeout is reached, PrivsepTimeout is raised.
PrivContext also has a timeout attribute, which is applied to all
functions decorated with entrypoint; PrivsepTimeout is raised if that
timeout is reached.

Co-authored-by: Rodolfo Alonso <ralonsoh@redhat.com>
Change-Id: Ie3b1fc255c0c05fd5403b90ef49b954fe397fb77
Related-Bug: #1930401
This commit is contained in:
elajkat 2021-06-08 18:09:31 +02:00
parent fa47d53dcb
commit f7f3349d6a
8 changed files with 202 additions and 50 deletions

View File

@ -30,6 +30,31 @@ defines a sys_admin_pctxt with ``CAP_CHOWN``, ``CAP_DAC_OVERRIDE``,
capabilities.CAP_SYS_ADMIN],
)
Defining a context with timeout
-------------------------------
It is possible to initialize PrivContext with a timeout::
from oslo_privsep import capabilities
from oslo_privsep import priv_context
dhcp_release_cmd = priv_context.PrivContext(
__name__,
cfg_section='privsep_dhcp_release',
pypath=__name__ + '.dhcp_release_cmd',
capabilities=[capabilities.CAP_SYS_ADMIN,
capabilities.CAP_NET_ADMIN],
timeout=5
)
``PrivsepTimeout`` is raised if the timeout is reached.
.. warning::
The daemon (the root process) task won't stop when the timeout
is reached. That means fewer threads will be available if the related
thread never finishes.
Defining a privileged function
==============================
@ -51,6 +76,36 @@ generic ``update_file(filename, content)`` was created, it could be used to
overwrite any file in the filesystem, allowing easy escalation to root
rights. That would defeat the whole purpose of oslo.privsep.
Defining a privileged function with timeout
-------------------------------------------
It is possible to use the ``entrypoint_with_timeout`` decorator::
from oslo_privsep import comm
from oslo_privsep import daemon
from neutron import privileged
@privileged.default.entrypoint_with_timeout(timeout=5)
def get_link_devices(namespace, **kwargs):
try:
with get_iproute(namespace) as ip:
return make_serializable(ip.get_links(**kwargs))
except OSError as e:
if e.errno == errno.ENOENT:
raise NetworkNamespaceNotFound(netns_name=namespace)
raise
except daemon.FailedToDropPrivileges:
raise
except comm.PrivsepTimeout:
raise
``PrivsepTimeout`` is raised if the timeout is reached.
.. warning::
The daemon (the root process) task won't stop when the timeout
is reached. That means fewer threads will be available if the related
thread never finishes.
Using a privileged function
===========================

View File

@ -20,6 +20,8 @@ python datatypes. Msgpack 'raw' is assumed to be a valid utf8 string
converted to tuples during serialization/deserialization.
"""
import datetime
import enum
import logging
import socket
import threading
@ -28,22 +30,24 @@ import msgpack
import six
from oslo_privsep._i18n import _
from oslo_utils import uuidutils
LOG = logging.getLogger(__name__)
try:
import greenlet
@enum.unique
class Message(enum.IntEnum):
"""Types of messages sent across the communication channel"""
PING = 1
PONG = 2
CALL = 3
RET = 4
ERR = 5
LOG = 6
def _get_thread_ident():
# This returns something sensible, even if the current thread
# isn't a greenthread
return id(greenlet.getcurrent())
except ImportError:
def _get_thread_ident():
return threading.current_thread().ident
class PrivsepTimeout(Exception):
pass
class Serializer(object):
@ -89,10 +93,11 @@ class Deserializer(six.Iterator):
class Future(object):
"""A very simple object to track the return of a function call"""
def __init__(self, lock):
def __init__(self, lock, timeout=None):
self.condvar = threading.Condition(lock)
self.error = None
self.data = None
self.timeout = timeout
def set_result(self, data):
"""Must already be holding lock used in constructor"""
@ -106,7 +111,16 @@ class Future(object):
def result(self):
"""Must already be holding lock used in constructor"""
self.condvar.wait()
before = datetime.datetime.now()
if not self.condvar.wait(timeout=self.timeout):
now = datetime.datetime.now()
LOG.warning('Timeout while executing a command, timeout: %s, '
'time elapsed: %s', self.timeout,
(now - before).total_seconds())
return (Message.ERR.value,
'%s.%s' % (PrivsepTimeout.__module__,
PrivsepTimeout.__name__),
'')
if self.error is not None:
raise self.error
return self.data
@ -138,8 +152,9 @@ class ClientChannel(object):
else:
with self.lock:
if msgid not in self.outstanding_msgs:
raise AssertionError("msgid should in "
"outstanding_msgs.")
LOG.warning("msgid should be in oustanding_msgs, it is"
"possible that timeout is reached!")
continue
self.outstanding_msgs[msgid].set_result(data)
# EOF. Perhaps the privileged process exited?
@ -158,13 +173,14 @@ class ClientChannel(object):
"""Received OOB message. Subclasses might want to override this."""
pass
def send_recv(self, msg):
myid = _get_thread_ident()
future = Future(self.lock)
def send_recv(self, msg, timeout=None):
myid = uuidutils.generate_uuid()
while myid in self.outstanding_msgs:
LOG.warning("myid shoudn't be in outstanding_msgs.")
myid = uuidutils.generate_uuid()
future = Future(self.lock, timeout)
with self.lock:
if myid in self.outstanding_msgs:
raise AssertionError("myid shoudn't be in outstanding_msgs.")
self.outstanding_msgs[myid] = future
try:
self.writer.send((myid, msg))

View File

@ -109,17 +109,6 @@ class StdioFd(enum.IntEnum):
STDERR = 2
@enum.unique
class Message(enum.IntEnum):
"""Types of messages sent across the communication channel"""
PING = 1
PONG = 2
CALL = 3
RET = 4
ERR = 5
LOG = 6
class FailedToDropPrivileges(Exception):
pass
@ -187,7 +176,7 @@ class PrivsepLogHandler(pylogging.Handler):
data['msg'] = record.getMessage()
data['args'] = ()
self.channel.send((None, (Message.LOG, data)))
self.channel.send((None, (comm.Message.LOG, data)))
class _ClientChannel(comm.ClientChannel):
@ -201,8 +190,8 @@ class _ClientChannel(comm.ClientChannel):
def exchange_ping(self):
try:
# exchange "ready" messages
reply = self.send_recv((Message.PING.value,))
success = reply[0] == Message.PONG
reply = self.send_recv((comm.Message.PING.value,))
success = reply[0] == comm.Message.PONG
except Exception as e:
self.log.exception('Error while sending initial PING to privsep: '
'%s', e)
@ -212,12 +201,13 @@ class _ClientChannel(comm.ClientChannel):
self.log.critical(msg)
raise FailedToDropPrivileges(msg)
def remote_call(self, name, args, kwargs):
result = self.send_recv((Message.CALL.value, name, args, kwargs))
if result[0] == Message.RET:
def remote_call(self, name, args, kwargs, timeout):
result = self.send_recv((comm.Message.CALL.value, name, args, kwargs),
timeout)
if result[0] == comm.Message.RET:
# (RET, return value)
return result[1]
elif result[0] == Message.ERR:
elif result[0] == comm.Message.ERR:
# (ERR, exc_type, args)
#
# TODO(gus): see what can be done to preserve traceback
@ -228,7 +218,7 @@ class _ClientChannel(comm.ClientChannel):
raise ProtocolError(_('Unexpected response: %r') % result)
def out_of_band(self, msg):
if msg[0] == Message.LOG:
if msg[0] == comm.Message.LOG:
# (LOG, LogRecord __dict__)
message = {encodeutils.safe_decode(k): v
for k, v in msg[1].items()}
@ -470,11 +460,11 @@ class Daemon(object):
:return: A tuple of the return status, optional call output, and
optional error information.
"""
if cmd == Message.PING:
return (Message.PONG.value,)
if cmd == comm.Message.PING:
return (comm.Message.PONG.value,)
try:
if cmd != Message.CALL:
if cmd != comm.Message.CALL:
raise ProtocolError(_('Unknown privsep cmd: %s') % cmd)
# Extract the callable and arguments
@ -485,14 +475,14 @@ class Daemon(object):
raise NameError(msg)
ret = func(*f_args, **f_kwargs)
return (Message.RET.value, ret)
return (comm.Message.RET.value, ret)
except Exception as e:
LOG.debug(
'privsep: Exception during request[%(msgid)s]: '
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
cls = e.__class__
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
return (Message.ERR.value, cls_name, e.args)
return (comm.Message.ERR.value, cls_name, e.args)
def _create_done_callback(self, msgid):
"""Creates a future callback to receive command execution results.
@ -520,7 +510,7 @@ class Daemon(object):
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
cls = e.__class__
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
reply = (Message.ERR.value, cls_name, e.args)
reply = (comm.Message.ERR.value, cls_name, e.args)
try:
channel.send((msgid, reply))
except IOError:

View File

@ -20,6 +20,7 @@ import unittest
from oslo_config import fixture as config_fixture
from oslotest import base
from oslo_privsep import comm
from oslo_privsep import priv_context
@ -30,6 +31,14 @@ test_context = priv_context.PrivContext(
capabilities=[],
)
test_context_with_timeout = priv_context.PrivContext(
__name__,
cfg_section='privsep',
pypath=__name__ + '.test_context_with_timeout',
capabilities=[],
timeout=0.03
)
@test_context.entrypoint
def sleep():
@ -37,6 +46,18 @@ def sleep():
time.sleep(.001)
@test_context.entrypoint_with_timeout(0.03)
def sleep_with_timeout(long_timeout=0.04):
time.sleep(long_timeout)
return 42
@test_context_with_timeout.entrypoint
def sleep_with_t_context(long_timeout=0.04):
time.sleep(long_timeout)
return 42
@test_context.entrypoint
def one():
return 1
@ -65,6 +86,28 @@ class TestDaemon(base.BaseTestCase):
# Make sure the daemon is still working
self.assertEqual(1, one())
def test_entrypoint_with_timeout(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
self.assertRaises(comm.PrivsepTimeout, sleep_with_timeout)
def test_entrypoint_with_timeout_pass(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
res = sleep_with_timeout(0.01)
self.assertEqual(42, res)
def test_context_with_timeout(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
self.assertRaises(comm.PrivsepTimeout, sleep_with_t_context)
def test_context_with_timeout_pass(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
res = sleep_with_t_context(0.01)
self.assertEqual(42, res)
def test_logging(self):
logs()
self.assertIn('foo', self.log_fixture.logger.output)

View File

@ -128,7 +128,8 @@ def init(root_helper=None):
class PrivContext(object):
def __init__(self, prefix, cfg_section='privsep', pypath=None,
capabilities=None, logger_name='oslo_privsep.daemon'):
capabilities=None, logger_name='oslo_privsep.daemon',
timeout=None):
# Note that capabilities=[] means retaining no capabilities
# and leaves even uid=0 with no powers except being able to
@ -156,6 +157,7 @@ class PrivContext(object):
default=capabilities)
cfg.CONF.set_default('logger_name', group=cfg_section,
default=logger_name)
self.timeout = timeout
@property
def conf(self):
@ -221,7 +223,22 @@ class PrivContext(object):
def entrypoint(self, func):
"""This is intended to be used as a decorator."""
return self._entrypoint(func)
def entrypoint_with_timeout(self, timeout):
"""This is intended to be used as a decorator with timeout."""
def wrap(func):
@functools.wraps(func)
def inner(*args, **kwargs):
f = self._entrypoint(func)
return f(*args, _wrap_timeout=timeout, **kwargs)
setattr(inner, _ENTRYPOINT_ATTR, self)
return inner
return wrap
def _entrypoint(self, func):
if not func.__module__.startswith(self.prefix):
raise AssertionError('%r entrypoints must be below "%s"' %
(self, self.prefix))
@ -242,7 +259,7 @@ class PrivContext(object):
def is_entrypoint(self, func):
return getattr(func, _ENTRYPOINT_ATTR, None) is self
def _wrap(self, func, *args, **kwargs):
def _wrap(self, func, *args, _wrap_timeout=None, **kwargs):
if self.client_mode:
name = '%s.%s' % (func.__module__, func.__name__)
if self.channel is not None and not self.channel.running:
@ -250,7 +267,9 @@ class PrivContext(object):
self.stop()
if self.channel is None:
self.start()
return self.channel.remote_call(name, args, kwargs)
r_call_timeout = _wrap_timeout or self.timeout
return self.channel.remote_call(name, args, kwargs,
r_call_timeout)
else:
return func(*args, **kwargs)

View File

@ -216,7 +216,7 @@ class ClientChannelTestCase(base.BaseTestCase):
@mock.patch.object(daemon.LOG.logger, 'handle')
def test_out_of_band_log_message(self, handle_mock):
message = [daemon.Message.LOG, self.DICT]
message = [comm.Message.LOG, self.DICT]
self.assertEqual(self.client_channel.log, daemon.LOG)
with mock.patch.object(pylogging, 'makeLogRecord') as mock_make_log, \
mock.patch.object(daemon.LOG, 'isEnabledFor',
@ -229,7 +229,7 @@ class ClientChannelTestCase(base.BaseTestCase):
def test_out_of_band_not_log_message(self):
with mock.patch.object(daemon.LOG, 'warning') as mock_warning:
self.client_channel.out_of_band([daemon.Message.PING])
self.client_channel.out_of_band([comm.Message.PING])
mock_warning.assert_called_once()
@mock.patch.object(daemon.logging, 'getLogger')
@ -245,7 +245,7 @@ class ClientChannelTestCase(base.BaseTestCase):
get_logger_mock.assert_called_once_with(logger_name)
self.assertEqual(get_logger_mock.return_value, channel.log)
message = [daemon.Message.LOG, self.DICT]
message = [comm.Message.LOG, self.DICT]
channel.out_of_band(message)
make_log_mock.assert_called_once_with(self.EXPECTED)

View File

@ -19,10 +19,12 @@ import pipes
import platform
import sys
import tempfile
import time
from unittest import mock
import testtools
from oslo_privsep import comm
from oslo_privsep import daemon
from oslo_privsep import priv_context
from oslo_privsep.tests import testctx
@ -40,6 +42,12 @@ def add1(arg):
return arg + 1
@testctx.context.entrypoint_with_timeout(0.2)
def do_some_long(long_timeout=0.4):
time.sleep(long_timeout)
return 42
class CustomError(Exception):
def __init__(self, code, msg):
super(CustomError, self).__init__(code, msg)
@ -188,6 +196,16 @@ class RootwrapTest(testctx.TestContextTestCase):
priv_pid = priv_getpid()
self.assertNotMyPid(priv_pid)
def test_long_call_with_timeout(self):
self.assertRaises(
comm.PrivsepTimeout,
do_some_long
)
def test_long_call_within_timeout(self):
res = do_some_long(0.001)
self.assertEqual(42, res)
@testtools.skipIf(platform.system() != 'Linux',
'works only on Linux platform.')

View File

@ -0,0 +1,11 @@
---
features:
- |
Add a ``timeout`` parameter to ``PrivContext`` and an
``entrypoint_with_timeout`` decorator to handle commands that take
an unpredictable amount of time to finish.
``PrivsepTimeout`` is raised if the timeout is reached.
``Warning``: The daemon (the root process) task won't stop when the timeout
is reached. That means fewer threads will be available if the related
thread never finishes.