Add timeout to PrivContext and entrypoint_with_timeout decorator

The entrypoint_with_timeout decorator takes a timeout parameter; if the
timeout is reached, PrivsepTimeout is raised.
PrivContext also accepts a timeout, which is used for all functions
decorated with entrypoint; PrivsepTimeout is raised if that timeout is
reached.

Co-authored-by: Rodolfo Alonso <ralonsoh@redhat.com>
Change-Id: Ie3b1fc255c0c05fd5403b90ef49b954fe397fb77
Related-Bug: #1930401
This commit is contained in:
elajkat 2021-06-08 18:09:31 +02:00
parent fa47d53dcb
commit f7f3349d6a
8 changed files with 202 additions and 50 deletions

View File

@ -30,6 +30,31 @@ defines a sys_admin_pctxt with ``CAP_CHOWN``, ``CAP_DAC_OVERRIDE``,
capabilities.CAP_SYS_ADMIN], capabilities.CAP_SYS_ADMIN],
) )
Defining a context with timeout
-------------------------------
It is possible to initialize ``PrivContext`` with a timeout::
from oslo_privsep import capabilities
from oslo_privsep import priv_context
dhcp_release_cmd = priv_context.PrivContext(
__name__,
cfg_section='privsep_dhcp_release',
pypath=__name__ + '.dhcp_release_cmd',
capabilities=[capabilities.CAP_SYS_ADMIN,
capabilities.CAP_NET_ADMIN],
timeout=5
)
``PrivsepTimeout`` is raised if the timeout is reached.
.. warning::
The daemon (the root process) task won't stop when the timeout
is reached. That means there will be fewer available threads if the related
thread never finishes.
Defining a privileged function Defining a privileged function
============================== ==============================
@ -51,6 +76,36 @@ generic ``update_file(filename, content)`` was created, it could be used to
overwrite any file in the filesystem, allowing easy escalation to root overwrite any file in the filesystem, allowing easy escalation to root
rights. That would defeat the whole purpose of oslo.privsep. rights. That would defeat the whole purpose of oslo.privsep.
Defining a privileged function with timeout
-------------------------------------------
It is possible to use the ``entrypoint_with_timeout`` decorator::
from oslo_privsep import daemon
from neutron import privileged
@privileged.default.entrypoint_with_timeout(timeout=5)
def get_link_devices(namespace, **kwargs):
try:
with get_iproute(namespace) as ip:
return make_serializable(ip.get_links(**kwargs))
except OSError as e:
if e.errno == errno.ENOENT:
raise NetworkNamespaceNotFound(netns_name=namespace)
raise
except daemon.FailedToDropPrivileges:
raise
except daemon.PrivsepTimeout:
raise
``PrivsepTimeout`` is raised if the timeout is reached.
.. warning::
The daemon (the root process) task won't stop when the timeout
is reached. That means there will be fewer available threads if the related
thread never finishes.
Using a privileged function Using a privileged function
=========================== ===========================

View File

@ -20,6 +20,8 @@ python datatypes. Msgpack 'raw' is assumed to be a valid utf8 string
converted to tuples during serialization/deserialization. converted to tuples during serialization/deserialization.
""" """
import datetime
import enum
import logging import logging
import socket import socket
import threading import threading
@ -28,22 +30,24 @@ import msgpack
import six import six
from oslo_privsep._i18n import _ from oslo_privsep._i18n import _
from oslo_utils import uuidutils
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
try: @enum.unique
import greenlet class Message(enum.IntEnum):
"""Types of messages sent across the communication channel"""
PING = 1
PONG = 2
CALL = 3
RET = 4
ERR = 5
LOG = 6
def _get_thread_ident():
# This returns something sensible, even if the current thread
# isn't a greenthread
return id(greenlet.getcurrent())
except ImportError: class PrivsepTimeout(Exception):
def _get_thread_ident(): pass
return threading.current_thread().ident
class Serializer(object): class Serializer(object):
@ -89,10 +93,11 @@ class Deserializer(six.Iterator):
class Future(object): class Future(object):
"""A very simple object to track the return of a function call""" """A very simple object to track the return of a function call"""
def __init__(self, lock): def __init__(self, lock, timeout=None):
self.condvar = threading.Condition(lock) self.condvar = threading.Condition(lock)
self.error = None self.error = None
self.data = None self.data = None
self.timeout = timeout
def set_result(self, data): def set_result(self, data):
"""Must already be holding lock used in constructor""" """Must already be holding lock used in constructor"""
@ -106,7 +111,16 @@ class Future(object):
def result(self): def result(self):
"""Must already be holding lock used in constructor""" """Must already be holding lock used in constructor"""
self.condvar.wait() before = datetime.datetime.now()
if not self.condvar.wait(timeout=self.timeout):
now = datetime.datetime.now()
LOG.warning('Timeout while executing a command, timeout: %s, '
'time elapsed: %s', self.timeout,
(now - before).total_seconds())
return (Message.ERR.value,
'%s.%s' % (PrivsepTimeout.__module__,
PrivsepTimeout.__name__),
'')
if self.error is not None: if self.error is not None:
raise self.error raise self.error
return self.data return self.data
@ -138,8 +152,9 @@ class ClientChannel(object):
else: else:
with self.lock: with self.lock:
if msgid not in self.outstanding_msgs: if msgid not in self.outstanding_msgs:
raise AssertionError("msgid should in " LOG.warning("msgid should be in oustanding_msgs, it is"
"outstanding_msgs.") "possible that timeout is reached!")
continue
self.outstanding_msgs[msgid].set_result(data) self.outstanding_msgs[msgid].set_result(data)
# EOF. Perhaps the privileged process exited? # EOF. Perhaps the privileged process exited?
@ -158,13 +173,14 @@ class ClientChannel(object):
"""Received OOB message. Subclasses might want to override this.""" """Received OOB message. Subclasses might want to override this."""
pass pass
def send_recv(self, msg): def send_recv(self, msg, timeout=None):
myid = _get_thread_ident() myid = uuidutils.generate_uuid()
future = Future(self.lock) while myid in self.outstanding_msgs:
LOG.warning("myid shoudn't be in outstanding_msgs.")
myid = uuidutils.generate_uuid()
future = Future(self.lock, timeout)
with self.lock: with self.lock:
if myid in self.outstanding_msgs:
raise AssertionError("myid shoudn't be in outstanding_msgs.")
self.outstanding_msgs[myid] = future self.outstanding_msgs[myid] = future
try: try:
self.writer.send((myid, msg)) self.writer.send((myid, msg))

View File

@ -109,17 +109,6 @@ class StdioFd(enum.IntEnum):
STDERR = 2 STDERR = 2
@enum.unique
class Message(enum.IntEnum):
"""Types of messages sent across the communication channel"""
PING = 1
PONG = 2
CALL = 3
RET = 4
ERR = 5
LOG = 6
class FailedToDropPrivileges(Exception): class FailedToDropPrivileges(Exception):
pass pass
@ -187,7 +176,7 @@ class PrivsepLogHandler(pylogging.Handler):
data['msg'] = record.getMessage() data['msg'] = record.getMessage()
data['args'] = () data['args'] = ()
self.channel.send((None, (Message.LOG, data))) self.channel.send((None, (comm.Message.LOG, data)))
class _ClientChannel(comm.ClientChannel): class _ClientChannel(comm.ClientChannel):
@ -201,8 +190,8 @@ class _ClientChannel(comm.ClientChannel):
def exchange_ping(self): def exchange_ping(self):
try: try:
# exchange "ready" messages # exchange "ready" messages
reply = self.send_recv((Message.PING.value,)) reply = self.send_recv((comm.Message.PING.value,))
success = reply[0] == Message.PONG success = reply[0] == comm.Message.PONG
except Exception as e: except Exception as e:
self.log.exception('Error while sending initial PING to privsep: ' self.log.exception('Error while sending initial PING to privsep: '
'%s', e) '%s', e)
@ -212,12 +201,13 @@ class _ClientChannel(comm.ClientChannel):
self.log.critical(msg) self.log.critical(msg)
raise FailedToDropPrivileges(msg) raise FailedToDropPrivileges(msg)
def remote_call(self, name, args, kwargs): def remote_call(self, name, args, kwargs, timeout):
result = self.send_recv((Message.CALL.value, name, args, kwargs)) result = self.send_recv((comm.Message.CALL.value, name, args, kwargs),
if result[0] == Message.RET: timeout)
if result[0] == comm.Message.RET:
# (RET, return value) # (RET, return value)
return result[1] return result[1]
elif result[0] == Message.ERR: elif result[0] == comm.Message.ERR:
# (ERR, exc_type, args) # (ERR, exc_type, args)
# #
# TODO(gus): see what can be done to preserve traceback # TODO(gus): see what can be done to preserve traceback
@ -228,7 +218,7 @@ class _ClientChannel(comm.ClientChannel):
raise ProtocolError(_('Unexpected response: %r') % result) raise ProtocolError(_('Unexpected response: %r') % result)
def out_of_band(self, msg): def out_of_band(self, msg):
if msg[0] == Message.LOG: if msg[0] == comm.Message.LOG:
# (LOG, LogRecord __dict__) # (LOG, LogRecord __dict__)
message = {encodeutils.safe_decode(k): v message = {encodeutils.safe_decode(k): v
for k, v in msg[1].items()} for k, v in msg[1].items()}
@ -470,11 +460,11 @@ class Daemon(object):
:return: A tuple of the return status, optional call output, and :return: A tuple of the return status, optional call output, and
optional error information. optional error information.
""" """
if cmd == Message.PING: if cmd == comm.Message.PING:
return (Message.PONG.value,) return (comm.Message.PONG.value,)
try: try:
if cmd != Message.CALL: if cmd != comm.Message.CALL:
raise ProtocolError(_('Unknown privsep cmd: %s') % cmd) raise ProtocolError(_('Unknown privsep cmd: %s') % cmd)
# Extract the callable and arguments # Extract the callable and arguments
@ -485,14 +475,14 @@ class Daemon(object):
raise NameError(msg) raise NameError(msg)
ret = func(*f_args, **f_kwargs) ret = func(*f_args, **f_kwargs)
return (Message.RET.value, ret) return (comm.Message.RET.value, ret)
except Exception as e: except Exception as e:
LOG.debug( LOG.debug(
'privsep: Exception during request[%(msgid)s]: ' 'privsep: Exception during request[%(msgid)s]: '
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True) '%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
cls = e.__class__ cls = e.__class__
cls_name = '%s.%s' % (cls.__module__, cls.__name__) cls_name = '%s.%s' % (cls.__module__, cls.__name__)
return (Message.ERR.value, cls_name, e.args) return (comm.Message.ERR.value, cls_name, e.args)
def _create_done_callback(self, msgid): def _create_done_callback(self, msgid):
"""Creates a future callback to receive command execution results. """Creates a future callback to receive command execution results.
@ -520,7 +510,7 @@ class Daemon(object):
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True) '%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
cls = e.__class__ cls = e.__class__
cls_name = '%s.%s' % (cls.__module__, cls.__name__) cls_name = '%s.%s' % (cls.__module__, cls.__name__)
reply = (Message.ERR.value, cls_name, e.args) reply = (comm.Message.ERR.value, cls_name, e.args)
try: try:
channel.send((msgid, reply)) channel.send((msgid, reply))
except IOError: except IOError:

View File

@ -20,6 +20,7 @@ import unittest
from oslo_config import fixture as config_fixture from oslo_config import fixture as config_fixture
from oslotest import base from oslotest import base
from oslo_privsep import comm
from oslo_privsep import priv_context from oslo_privsep import priv_context
@ -30,6 +31,14 @@ test_context = priv_context.PrivContext(
capabilities=[], capabilities=[],
) )
test_context_with_timeout = priv_context.PrivContext(
__name__,
cfg_section='privsep',
pypath=__name__ + '.test_context_with_timeout',
capabilities=[],
timeout=0.03
)
@test_context.entrypoint @test_context.entrypoint
def sleep(): def sleep():
@ -37,6 +46,18 @@ def sleep():
time.sleep(.001) time.sleep(.001)
@test_context.entrypoint_with_timeout(0.03)
def sleep_with_timeout(long_timeout=0.04):
time.sleep(long_timeout)
return 42
@test_context_with_timeout.entrypoint
def sleep_with_t_context(long_timeout=0.04):
time.sleep(long_timeout)
return 42
@test_context.entrypoint @test_context.entrypoint
def one(): def one():
return 1 return 1
@ -65,6 +86,28 @@ class TestDaemon(base.BaseTestCase):
# Make sure the daemon is still working # Make sure the daemon is still working
self.assertEqual(1, one()) self.assertEqual(1, one())
def test_entrypoint_with_timeout(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
self.assertRaises(comm.PrivsepTimeout, sleep_with_timeout)
def test_entrypoint_with_timeout_pass(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
res = sleep_with_timeout(0.01)
self.assertEqual(42, res)
def test_context_with_timeout(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
self.assertRaises(comm.PrivsepTimeout, sleep_with_t_context)
def test_context_with_timeout_pass(self):
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
for _ in range(thread_pool_size + 1):
res = sleep_with_t_context(0.01)
self.assertEqual(42, res)
def test_logging(self): def test_logging(self):
logs() logs()
self.assertIn('foo', self.log_fixture.logger.output) self.assertIn('foo', self.log_fixture.logger.output)

View File

@ -128,7 +128,8 @@ def init(root_helper=None):
class PrivContext(object): class PrivContext(object):
def __init__(self, prefix, cfg_section='privsep', pypath=None, def __init__(self, prefix, cfg_section='privsep', pypath=None,
capabilities=None, logger_name='oslo_privsep.daemon'): capabilities=None, logger_name='oslo_privsep.daemon',
timeout=None):
# Note that capabilities=[] means retaining no capabilities # Note that capabilities=[] means retaining no capabilities
# and leaves even uid=0 with no powers except being able to # and leaves even uid=0 with no powers except being able to
@ -156,6 +157,7 @@ class PrivContext(object):
default=capabilities) default=capabilities)
cfg.CONF.set_default('logger_name', group=cfg_section, cfg.CONF.set_default('logger_name', group=cfg_section,
default=logger_name) default=logger_name)
self.timeout = timeout
@property @property
def conf(self): def conf(self):
@ -221,7 +223,22 @@ class PrivContext(object):
def entrypoint(self, func): def entrypoint(self, func):
"""This is intended to be used as a decorator.""" """This is intended to be used as a decorator."""
return self._entrypoint(func)
def entrypoint_with_timeout(self, timeout):
"""This is intended to be used as a decorator with timeout."""
def wrap(func):
@functools.wraps(func)
def inner(*args, **kwargs):
f = self._entrypoint(func)
return f(*args, _wrap_timeout=timeout, **kwargs)
setattr(inner, _ENTRYPOINT_ATTR, self)
return inner
return wrap
def _entrypoint(self, func):
if not func.__module__.startswith(self.prefix): if not func.__module__.startswith(self.prefix):
raise AssertionError('%r entrypoints must be below "%s"' % raise AssertionError('%r entrypoints must be below "%s"' %
(self, self.prefix)) (self, self.prefix))
@ -242,7 +259,7 @@ class PrivContext(object):
def is_entrypoint(self, func): def is_entrypoint(self, func):
return getattr(func, _ENTRYPOINT_ATTR, None) is self return getattr(func, _ENTRYPOINT_ATTR, None) is self
def _wrap(self, func, *args, **kwargs): def _wrap(self, func, *args, _wrap_timeout=None, **kwargs):
if self.client_mode: if self.client_mode:
name = '%s.%s' % (func.__module__, func.__name__) name = '%s.%s' % (func.__module__, func.__name__)
if self.channel is not None and not self.channel.running: if self.channel is not None and not self.channel.running:
@ -250,7 +267,9 @@ class PrivContext(object):
self.stop() self.stop()
if self.channel is None: if self.channel is None:
self.start() self.start()
return self.channel.remote_call(name, args, kwargs) r_call_timeout = _wrap_timeout or self.timeout
return self.channel.remote_call(name, args, kwargs,
r_call_timeout)
else: else:
return func(*args, **kwargs) return func(*args, **kwargs)

View File

@ -216,7 +216,7 @@ class ClientChannelTestCase(base.BaseTestCase):
@mock.patch.object(daemon.LOG.logger, 'handle') @mock.patch.object(daemon.LOG.logger, 'handle')
def test_out_of_band_log_message(self, handle_mock): def test_out_of_band_log_message(self, handle_mock):
message = [daemon.Message.LOG, self.DICT] message = [comm.Message.LOG, self.DICT]
self.assertEqual(self.client_channel.log, daemon.LOG) self.assertEqual(self.client_channel.log, daemon.LOG)
with mock.patch.object(pylogging, 'makeLogRecord') as mock_make_log, \ with mock.patch.object(pylogging, 'makeLogRecord') as mock_make_log, \
mock.patch.object(daemon.LOG, 'isEnabledFor', mock.patch.object(daemon.LOG, 'isEnabledFor',
@ -229,7 +229,7 @@ class ClientChannelTestCase(base.BaseTestCase):
def test_out_of_band_not_log_message(self): def test_out_of_band_not_log_message(self):
with mock.patch.object(daemon.LOG, 'warning') as mock_warning: with mock.patch.object(daemon.LOG, 'warning') as mock_warning:
self.client_channel.out_of_band([daemon.Message.PING]) self.client_channel.out_of_band([comm.Message.PING])
mock_warning.assert_called_once() mock_warning.assert_called_once()
@mock.patch.object(daemon.logging, 'getLogger') @mock.patch.object(daemon.logging, 'getLogger')
@ -245,7 +245,7 @@ class ClientChannelTestCase(base.BaseTestCase):
get_logger_mock.assert_called_once_with(logger_name) get_logger_mock.assert_called_once_with(logger_name)
self.assertEqual(get_logger_mock.return_value, channel.log) self.assertEqual(get_logger_mock.return_value, channel.log)
message = [daemon.Message.LOG, self.DICT] message = [comm.Message.LOG, self.DICT]
channel.out_of_band(message) channel.out_of_band(message)
make_log_mock.assert_called_once_with(self.EXPECTED) make_log_mock.assert_called_once_with(self.EXPECTED)

View File

@ -19,10 +19,12 @@ import pipes
import platform import platform
import sys import sys
import tempfile import tempfile
import time
from unittest import mock from unittest import mock
import testtools import testtools
from oslo_privsep import comm
from oslo_privsep import daemon from oslo_privsep import daemon
from oslo_privsep import priv_context from oslo_privsep import priv_context
from oslo_privsep.tests import testctx from oslo_privsep.tests import testctx
@ -40,6 +42,12 @@ def add1(arg):
return arg + 1 return arg + 1
@testctx.context.entrypoint_with_timeout(0.2)
def do_some_long(long_timeout=0.4):
time.sleep(long_timeout)
return 42
class CustomError(Exception): class CustomError(Exception):
def __init__(self, code, msg): def __init__(self, code, msg):
super(CustomError, self).__init__(code, msg) super(CustomError, self).__init__(code, msg)
@ -188,6 +196,16 @@ class RootwrapTest(testctx.TestContextTestCase):
priv_pid = priv_getpid() priv_pid = priv_getpid()
self.assertNotMyPid(priv_pid) self.assertNotMyPid(priv_pid)
def test_long_call_with_timeout(self):
self.assertRaises(
comm.PrivsepTimeout,
do_some_long
)
def test_long_call_within_timeout(self):
res = do_some_long(0.001)
self.assertEqual(42, res)
@testtools.skipIf(platform.system() != 'Linux', @testtools.skipIf(platform.system() != 'Linux',
'works only on Linux platform.') 'works only on Linux platform.')

View File

@ -0,0 +1,11 @@
---
features:
- |
Add ``timeout`` as a parameter to ``PrivContext`` and add the
``entrypoint_with_timeout`` decorator to handle commands which
take an unpredictable amount of time to finish.
``PrivsepTimeout`` is raised if the timeout is reached.
``Warning``: The daemon (the root process) task won't stop when the timeout
is reached. That means there will be fewer available threads if the related
thread never finishes.