Add timeout to PrivContext and entrypoint_with_timeout decorator
entrypoint_with_timeout decorator can be used with a timeout parameter, if the timeout is reached PrivsepTimeout is raised. The PrivContext has timeout variable, which will be used for all functions decorated with entrypoint, and PrivsepTimeout is raised if timeout is reached. Co-authored-by: Rodolfo Alonso <ralonsoh@redhat.com> Change-Id: Ie3b1fc255c0c05fd5403b90ef49b954fe397fb77 Related-Bug: #1930401
This commit is contained in:
parent
fa47d53dcb
commit
f7f3349d6a
@ -30,6 +30,31 @@ defines a sys_admin_pctxt with ``CAP_CHOWN``, ``CAP_DAC_OVERRIDE``,
|
||||
capabilities.CAP_SYS_ADMIN],
|
||||
)
|
||||
|
||||
Defining a context with timeout
|
||||
-------------------------------
|
||||
|
||||
It is possible to initialize PrivContext with timeout::
|
||||
|
||||
from oslo_privsep import capabilities
|
||||
from oslo_privsep import priv_context
|
||||
|
||||
dhcp_release_cmd = priv_context.PrivContext(
|
||||
__name__,
|
||||
cfg_section='privsep_dhcp_release',
|
||||
pypath=__name__ + '.dhcp_release_cmd',
|
||||
capabilities=[caps.CAP_SYS_ADMIN,
|
||||
caps.CAP_NET_ADMIN],
|
||||
timeout=5
|
||||
)
|
||||
|
||||
``PrivsepTimeout`` is raised if timeout is reached.
|
||||
|
||||
.. warning::
|
||||
|
||||
The daemon (the root process) task won't stop when timeout
|
||||
is reached. That means we'll have less available threads if the related
|
||||
thread never finishes.
|
||||
|
||||
Defining a privileged function
|
||||
==============================
|
||||
|
||||
@ -51,6 +76,36 @@ generic ``update_file(filename, content)`` was created, it could be used to
|
||||
overwrite any file in the filesystem, allowing easy escalation to root
|
||||
rights. That would defeat the whole purpose of oslo.privsep.
|
||||
|
||||
Defining a privileged function with timeout
|
||||
-------------------------------------------
|
||||
|
||||
It is possible to use ``entrypoint_with_timeout`` decorator::
|
||||
|
||||
from oslo_privsep import daemon
|
||||
|
||||
from neutron import privileged
|
||||
|
||||
@privileged.default.entrypoint_with_timeout(timeout=5)
|
||||
def get_link_devices(namespace, **kwargs):
|
||||
try:
|
||||
with get_iproute(namespace) as ip:
|
||||
return make_serializable(ip.get_links(**kwargs))
|
||||
except OSError as e:
|
||||
if e.errno == errno.ENOENT:
|
||||
raise NetworkNamespaceNotFound(netns_name=namespace)
|
||||
raise
|
||||
except daemon.FailedToDropPrivileges:
|
||||
raise
|
||||
except daemon.PrivsepTimeout:
|
||||
raise
|
||||
|
||||
``PrivsepTimeout`` is raised if timeout is reached.
|
||||
|
||||
.. warning::
|
||||
|
||||
The daemon (the root process) task won't stop when timeout
|
||||
is reached. That means we'll have less available threads if the related
|
||||
thread never finishes.
|
||||
|
||||
Using a privileged function
|
||||
===========================
|
||||
|
@ -20,6 +20,8 @@ python datatypes. Msgpack 'raw' is assumed to be a valid utf8 string
|
||||
converted to tuples during serialization/deserialization.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
@ -28,22 +30,24 @@ import msgpack
|
||||
import six
|
||||
|
||||
from oslo_privsep._i18n import _
|
||||
|
||||
from oslo_utils import uuidutils
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
import greenlet
|
||||
@enum.unique
|
||||
class Message(enum.IntEnum):
|
||||
"""Types of messages sent across the communication channel"""
|
||||
PING = 1
|
||||
PONG = 2
|
||||
CALL = 3
|
||||
RET = 4
|
||||
ERR = 5
|
||||
LOG = 6
|
||||
|
||||
def _get_thread_ident():
|
||||
# This returns something sensible, even if the current thread
|
||||
# isn't a greenthread
|
||||
return id(greenlet.getcurrent())
|
||||
|
||||
except ImportError:
|
||||
def _get_thread_ident():
|
||||
return threading.current_thread().ident
|
||||
class PrivsepTimeout(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Serializer(object):
|
||||
@ -89,10 +93,11 @@ class Deserializer(six.Iterator):
|
||||
class Future(object):
|
||||
"""A very simple object to track the return of a function call"""
|
||||
|
||||
def __init__(self, lock):
|
||||
def __init__(self, lock, timeout=None):
|
||||
self.condvar = threading.Condition(lock)
|
||||
self.error = None
|
||||
self.data = None
|
||||
self.timeout = timeout
|
||||
|
||||
def set_result(self, data):
|
||||
"""Must already be holding lock used in constructor"""
|
||||
@ -106,7 +111,16 @@ class Future(object):
|
||||
|
||||
def result(self):
|
||||
"""Must already be holding lock used in constructor"""
|
||||
self.condvar.wait()
|
||||
before = datetime.datetime.now()
|
||||
if not self.condvar.wait(timeout=self.timeout):
|
||||
now = datetime.datetime.now()
|
||||
LOG.warning('Timeout while executing a command, timeout: %s, '
|
||||
'time elapsed: %s', self.timeout,
|
||||
(now - before).total_seconds())
|
||||
return (Message.ERR.value,
|
||||
'%s.%s' % (PrivsepTimeout.__module__,
|
||||
PrivsepTimeout.__name__),
|
||||
'')
|
||||
if self.error is not None:
|
||||
raise self.error
|
||||
return self.data
|
||||
@ -138,8 +152,9 @@ class ClientChannel(object):
|
||||
else:
|
||||
with self.lock:
|
||||
if msgid not in self.outstanding_msgs:
|
||||
raise AssertionError("msgid should in "
|
||||
"outstanding_msgs.")
|
||||
LOG.warning("msgid should be in oustanding_msgs, it is"
|
||||
"possible that timeout is reached!")
|
||||
continue
|
||||
self.outstanding_msgs[msgid].set_result(data)
|
||||
|
||||
# EOF. Perhaps the privileged process exited?
|
||||
@ -158,13 +173,14 @@ class ClientChannel(object):
|
||||
"""Received OOB message. Subclasses might want to override this."""
|
||||
pass
|
||||
|
||||
def send_recv(self, msg):
|
||||
myid = _get_thread_ident()
|
||||
future = Future(self.lock)
|
||||
def send_recv(self, msg, timeout=None):
|
||||
myid = uuidutils.generate_uuid()
|
||||
while myid in self.outstanding_msgs:
|
||||
LOG.warning("myid shoudn't be in outstanding_msgs.")
|
||||
myid = uuidutils.generate_uuid()
|
||||
future = Future(self.lock, timeout)
|
||||
|
||||
with self.lock:
|
||||
if myid in self.outstanding_msgs:
|
||||
raise AssertionError("myid shoudn't be in outstanding_msgs.")
|
||||
self.outstanding_msgs[myid] = future
|
||||
try:
|
||||
self.writer.send((myid, msg))
|
||||
|
@ -109,17 +109,6 @@ class StdioFd(enum.IntEnum):
|
||||
STDERR = 2
|
||||
|
||||
|
||||
@enum.unique
|
||||
class Message(enum.IntEnum):
|
||||
"""Types of messages sent across the communication channel"""
|
||||
PING = 1
|
||||
PONG = 2
|
||||
CALL = 3
|
||||
RET = 4
|
||||
ERR = 5
|
||||
LOG = 6
|
||||
|
||||
|
||||
class FailedToDropPrivileges(Exception):
|
||||
pass
|
||||
|
||||
@ -187,7 +176,7 @@ class PrivsepLogHandler(pylogging.Handler):
|
||||
data['msg'] = record.getMessage()
|
||||
data['args'] = ()
|
||||
|
||||
self.channel.send((None, (Message.LOG, data)))
|
||||
self.channel.send((None, (comm.Message.LOG, data)))
|
||||
|
||||
|
||||
class _ClientChannel(comm.ClientChannel):
|
||||
@ -201,8 +190,8 @@ class _ClientChannel(comm.ClientChannel):
|
||||
def exchange_ping(self):
|
||||
try:
|
||||
# exchange "ready" messages
|
||||
reply = self.send_recv((Message.PING.value,))
|
||||
success = reply[0] == Message.PONG
|
||||
reply = self.send_recv((comm.Message.PING.value,))
|
||||
success = reply[0] == comm.Message.PONG
|
||||
except Exception as e:
|
||||
self.log.exception('Error while sending initial PING to privsep: '
|
||||
'%s', e)
|
||||
@ -212,12 +201,13 @@ class _ClientChannel(comm.ClientChannel):
|
||||
self.log.critical(msg)
|
||||
raise FailedToDropPrivileges(msg)
|
||||
|
||||
def remote_call(self, name, args, kwargs):
|
||||
result = self.send_recv((Message.CALL.value, name, args, kwargs))
|
||||
if result[0] == Message.RET:
|
||||
def remote_call(self, name, args, kwargs, timeout):
|
||||
result = self.send_recv((comm.Message.CALL.value, name, args, kwargs),
|
||||
timeout)
|
||||
if result[0] == comm.Message.RET:
|
||||
# (RET, return value)
|
||||
return result[1]
|
||||
elif result[0] == Message.ERR:
|
||||
elif result[0] == comm.Message.ERR:
|
||||
# (ERR, exc_type, args)
|
||||
#
|
||||
# TODO(gus): see what can be done to preserve traceback
|
||||
@ -228,7 +218,7 @@ class _ClientChannel(comm.ClientChannel):
|
||||
raise ProtocolError(_('Unexpected response: %r') % result)
|
||||
|
||||
def out_of_band(self, msg):
|
||||
if msg[0] == Message.LOG:
|
||||
if msg[0] == comm.Message.LOG:
|
||||
# (LOG, LogRecord __dict__)
|
||||
message = {encodeutils.safe_decode(k): v
|
||||
for k, v in msg[1].items()}
|
||||
@ -470,11 +460,11 @@ class Daemon(object):
|
||||
:return: A tuple of the return status, optional call output, and
|
||||
optional error information.
|
||||
"""
|
||||
if cmd == Message.PING:
|
||||
return (Message.PONG.value,)
|
||||
if cmd == comm.Message.PING:
|
||||
return (comm.Message.PONG.value,)
|
||||
|
||||
try:
|
||||
if cmd != Message.CALL:
|
||||
if cmd != comm.Message.CALL:
|
||||
raise ProtocolError(_('Unknown privsep cmd: %s') % cmd)
|
||||
|
||||
# Extract the callable and arguments
|
||||
@ -485,14 +475,14 @@ class Daemon(object):
|
||||
raise NameError(msg)
|
||||
|
||||
ret = func(*f_args, **f_kwargs)
|
||||
return (Message.RET.value, ret)
|
||||
return (comm.Message.RET.value, ret)
|
||||
except Exception as e:
|
||||
LOG.debug(
|
||||
'privsep: Exception during request[%(msgid)s]: '
|
||||
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
|
||||
cls = e.__class__
|
||||
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
|
||||
return (Message.ERR.value, cls_name, e.args)
|
||||
return (comm.Message.ERR.value, cls_name, e.args)
|
||||
|
||||
def _create_done_callback(self, msgid):
|
||||
"""Creates a future callback to receive command execution results.
|
||||
@ -520,7 +510,7 @@ class Daemon(object):
|
||||
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
|
||||
cls = e.__class__
|
||||
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
|
||||
reply = (Message.ERR.value, cls_name, e.args)
|
||||
reply = (comm.Message.ERR.value, cls_name, e.args)
|
||||
try:
|
||||
channel.send((msgid, reply))
|
||||
except IOError:
|
||||
|
@ -20,6 +20,7 @@ import unittest
|
||||
from oslo_config import fixture as config_fixture
|
||||
from oslotest import base
|
||||
|
||||
from oslo_privsep import comm
|
||||
from oslo_privsep import priv_context
|
||||
|
||||
|
||||
@ -30,6 +31,14 @@ test_context = priv_context.PrivContext(
|
||||
capabilities=[],
|
||||
)
|
||||
|
||||
test_context_with_timeout = priv_context.PrivContext(
|
||||
__name__,
|
||||
cfg_section='privsep',
|
||||
pypath=__name__ + '.test_context_with_timeout',
|
||||
capabilities=[],
|
||||
timeout=0.03
|
||||
)
|
||||
|
||||
|
||||
@test_context.entrypoint
|
||||
def sleep():
|
||||
@ -37,6 +46,18 @@ def sleep():
|
||||
time.sleep(.001)
|
||||
|
||||
|
||||
@test_context.entrypoint_with_timeout(0.03)
|
||||
def sleep_with_timeout(long_timeout=0.04):
|
||||
time.sleep(long_timeout)
|
||||
return 42
|
||||
|
||||
|
||||
@test_context_with_timeout.entrypoint
|
||||
def sleep_with_t_context(long_timeout=0.04):
|
||||
time.sleep(long_timeout)
|
||||
return 42
|
||||
|
||||
|
||||
@test_context.entrypoint
|
||||
def one():
|
||||
return 1
|
||||
@ -65,6 +86,28 @@ class TestDaemon(base.BaseTestCase):
|
||||
# Make sure the daemon is still working
|
||||
self.assertEqual(1, one())
|
||||
|
||||
def test_entrypoint_with_timeout(self):
|
||||
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||
for _ in range(thread_pool_size + 1):
|
||||
self.assertRaises(comm.PrivsepTimeout, sleep_with_timeout)
|
||||
|
||||
def test_entrypoint_with_timeout_pass(self):
|
||||
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||
for _ in range(thread_pool_size + 1):
|
||||
res = sleep_with_timeout(0.01)
|
||||
self.assertEqual(42, res)
|
||||
|
||||
def test_context_with_timeout(self):
|
||||
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||
for _ in range(thread_pool_size + 1):
|
||||
self.assertRaises(comm.PrivsepTimeout, sleep_with_t_context)
|
||||
|
||||
def test_context_with_timeout_pass(self):
|
||||
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||
for _ in range(thread_pool_size + 1):
|
||||
res = sleep_with_t_context(0.01)
|
||||
self.assertEqual(42, res)
|
||||
|
||||
def test_logging(self):
|
||||
logs()
|
||||
self.assertIn('foo', self.log_fixture.logger.output)
|
||||
|
@ -128,7 +128,8 @@ def init(root_helper=None):
|
||||
|
||||
class PrivContext(object):
|
||||
def __init__(self, prefix, cfg_section='privsep', pypath=None,
|
||||
capabilities=None, logger_name='oslo_privsep.daemon'):
|
||||
capabilities=None, logger_name='oslo_privsep.daemon',
|
||||
timeout=None):
|
||||
|
||||
# Note that capabilities=[] means retaining no capabilities
|
||||
# and leaves even uid=0 with no powers except being able to
|
||||
@ -156,6 +157,7 @@ class PrivContext(object):
|
||||
default=capabilities)
|
||||
cfg.CONF.set_default('logger_name', group=cfg_section,
|
||||
default=logger_name)
|
||||
self.timeout = timeout
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
@ -221,7 +223,22 @@ class PrivContext(object):
|
||||
|
||||
def entrypoint(self, func):
|
||||
"""This is intended to be used as a decorator."""
|
||||
return self._entrypoint(func)
|
||||
|
||||
def entrypoint_with_timeout(self, timeout):
|
||||
"""This is intended to be used as a decorator with timeout."""
|
||||
|
||||
def wrap(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
f = self._entrypoint(func)
|
||||
return f(*args, _wrap_timeout=timeout, **kwargs)
|
||||
setattr(inner, _ENTRYPOINT_ATTR, self)
|
||||
return inner
|
||||
return wrap
|
||||
|
||||
def _entrypoint(self, func):
|
||||
if not func.__module__.startswith(self.prefix):
|
||||
raise AssertionError('%r entrypoints must be below "%s"' %
|
||||
(self, self.prefix))
|
||||
@ -242,7 +259,7 @@ class PrivContext(object):
|
||||
def is_entrypoint(self, func):
|
||||
return getattr(func, _ENTRYPOINT_ATTR, None) is self
|
||||
|
||||
def _wrap(self, func, *args, **kwargs):
|
||||
def _wrap(self, func, *args, _wrap_timeout=None, **kwargs):
|
||||
if self.client_mode:
|
||||
name = '%s.%s' % (func.__module__, func.__name__)
|
||||
if self.channel is not None and not self.channel.running:
|
||||
@ -250,7 +267,9 @@ class PrivContext(object):
|
||||
self.stop()
|
||||
if self.channel is None:
|
||||
self.start()
|
||||
return self.channel.remote_call(name, args, kwargs)
|
||||
r_call_timeout = _wrap_timeout or self.timeout
|
||||
return self.channel.remote_call(name, args, kwargs,
|
||||
r_call_timeout)
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
@ -216,7 +216,7 @@ class ClientChannelTestCase(base.BaseTestCase):
|
||||
|
||||
@mock.patch.object(daemon.LOG.logger, 'handle')
|
||||
def test_out_of_band_log_message(self, handle_mock):
|
||||
message = [daemon.Message.LOG, self.DICT]
|
||||
message = [comm.Message.LOG, self.DICT]
|
||||
self.assertEqual(self.client_channel.log, daemon.LOG)
|
||||
with mock.patch.object(pylogging, 'makeLogRecord') as mock_make_log, \
|
||||
mock.patch.object(daemon.LOG, 'isEnabledFor',
|
||||
@ -229,7 +229,7 @@ class ClientChannelTestCase(base.BaseTestCase):
|
||||
|
||||
def test_out_of_band_not_log_message(self):
|
||||
with mock.patch.object(daemon.LOG, 'warning') as mock_warning:
|
||||
self.client_channel.out_of_band([daemon.Message.PING])
|
||||
self.client_channel.out_of_band([comm.Message.PING])
|
||||
mock_warning.assert_called_once()
|
||||
|
||||
@mock.patch.object(daemon.logging, 'getLogger')
|
||||
@ -245,7 +245,7 @@ class ClientChannelTestCase(base.BaseTestCase):
|
||||
get_logger_mock.assert_called_once_with(logger_name)
|
||||
self.assertEqual(get_logger_mock.return_value, channel.log)
|
||||
|
||||
message = [daemon.Message.LOG, self.DICT]
|
||||
message = [comm.Message.LOG, self.DICT]
|
||||
channel.out_of_band(message)
|
||||
|
||||
make_log_mock.assert_called_once_with(self.EXPECTED)
|
||||
|
@ -19,10 +19,12 @@ import pipes
|
||||
import platform
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
import testtools
|
||||
|
||||
from oslo_privsep import comm
|
||||
from oslo_privsep import daemon
|
||||
from oslo_privsep import priv_context
|
||||
from oslo_privsep.tests import testctx
|
||||
@ -40,6 +42,12 @@ def add1(arg):
|
||||
return arg + 1
|
||||
|
||||
|
||||
@testctx.context.entrypoint_with_timeout(0.2)
|
||||
def do_some_long(long_timeout=0.4):
|
||||
time.sleep(long_timeout)
|
||||
return 42
|
||||
|
||||
|
||||
class CustomError(Exception):
|
||||
def __init__(self, code, msg):
|
||||
super(CustomError, self).__init__(code, msg)
|
||||
@ -188,6 +196,16 @@ class RootwrapTest(testctx.TestContextTestCase):
|
||||
priv_pid = priv_getpid()
|
||||
self.assertNotMyPid(priv_pid)
|
||||
|
||||
def test_long_call_with_timeout(self):
|
||||
self.assertRaises(
|
||||
comm.PrivsepTimeout,
|
||||
do_some_long
|
||||
)
|
||||
|
||||
def test_long_call_within_timeout(self):
|
||||
res = do_some_long(0.001)
|
||||
self.assertEqual(42, res)
|
||||
|
||||
|
||||
@testtools.skipIf(platform.system() != 'Linux',
|
||||
'works only on Linux platform.')
|
||||
|
@ -0,0 +1,11 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add ``timeout`` as parameter to ``PrivContext`` and add
|
||||
``entrypoint_with_timeout`` decorator to cover the issues with
|
||||
commands which take random time to finish.
|
||||
``PrivsepTimeout`` is raised if timeout is reached.
|
||||
|
||||
``Warning``: The daemon (the root process) task won't stop when timeout
|
||||
is reached. That means we'll have less available threads if the related
|
||||
thread never finishes.
|
Loading…
Reference in New Issue
Block a user