Add timeout to PrivContext and entrypoint_with_timeout decorator
entrypoint_with_timeout decorator can be used with a timeout parameter, if the timeout is reached PrivsepTimeout is raised. The PrivContext has timeout variable, which will be used for all functions decorated with entrypoint, and PrivsepTimeout is raised if timeout is reached. Co-authored-by: Rodolfo Alonso <ralonsoh@redhat.com> Change-Id: Ie3b1fc255c0c05fd5403b90ef49b954fe397fb77 Related-Bug: #1930401
This commit is contained in:
parent
fa47d53dcb
commit
f7f3349d6a
@ -30,6 +30,31 @@ defines a sys_admin_pctxt with ``CAP_CHOWN``, ``CAP_DAC_OVERRIDE``,
|
|||||||
capabilities.CAP_SYS_ADMIN],
|
capabilities.CAP_SYS_ADMIN],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Defining a context with timeout
|
||||||
|
-------------------------------
|
||||||
|
|
||||||
|
It is possible to initialize PrivContext with timeout::
|
||||||
|
|
||||||
|
from oslo_privsep import capabilities
|
||||||
|
from oslo_privsep import priv_context
|
||||||
|
|
||||||
|
dhcp_release_cmd = priv_context.PrivContext(
|
||||||
|
__name__,
|
||||||
|
cfg_section='privsep_dhcp_release',
|
||||||
|
pypath=__name__ + '.dhcp_release_cmd',
|
||||||
|
capabilities=[caps.CAP_SYS_ADMIN,
|
||||||
|
caps.CAP_NET_ADMIN],
|
||||||
|
timeout=5
|
||||||
|
)
|
||||||
|
|
||||||
|
``PrivsepTimeout`` is raised if timeout is reached.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
The daemon (the root process) task won't stop when timeout
|
||||||
|
is reached. That means we'll have less available threads if the related
|
||||||
|
thread never finishes.
|
||||||
|
|
||||||
Defining a privileged function
|
Defining a privileged function
|
||||||
==============================
|
==============================
|
||||||
|
|
||||||
@ -51,6 +76,36 @@ generic ``update_file(filename, content)`` was created, it could be used to
|
|||||||
overwrite any file in the filesystem, allowing easy escalation to root
|
overwrite any file in the filesystem, allowing easy escalation to root
|
||||||
rights. That would defeat the whole purpose of oslo.privsep.
|
rights. That would defeat the whole purpose of oslo.privsep.
|
||||||
|
|
||||||
|
Defining a privileged function with timeout
|
||||||
|
-------------------------------------------
|
||||||
|
|
||||||
|
It is possible to use ``entrypoint_with_timeout`` decorator::
|
||||||
|
|
||||||
|
from oslo_privsep import daemon
|
||||||
|
|
||||||
|
from neutron import privileged
|
||||||
|
|
||||||
|
@privileged.default.entrypoint_with_timeout(timeout=5)
|
||||||
|
def get_link_devices(namespace, **kwargs):
|
||||||
|
try:
|
||||||
|
with get_iproute(namespace) as ip:
|
||||||
|
return make_serializable(ip.get_links(**kwargs))
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno == errno.ENOENT:
|
||||||
|
raise NetworkNamespaceNotFound(netns_name=namespace)
|
||||||
|
raise
|
||||||
|
except daemon.FailedToDropPrivileges:
|
||||||
|
raise
|
||||||
|
except daemon.PrivsepTimeout:
|
||||||
|
raise
|
||||||
|
|
||||||
|
``PrivsepTimeout`` is raised if timeout is reached.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
The daemon (the root process) task won't stop when timeout
|
||||||
|
is reached. That means we'll have less available threads if the related
|
||||||
|
thread never finishes.
|
||||||
|
|
||||||
Using a privileged function
|
Using a privileged function
|
||||||
===========================
|
===========================
|
||||||
|
@ -20,6 +20,8 @@ python datatypes. Msgpack 'raw' is assumed to be a valid utf8 string
|
|||||||
converted to tuples during serialization/deserialization.
|
converted to tuples during serialization/deserialization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import threading
|
import threading
|
||||||
@ -28,22 +30,24 @@ import msgpack
|
|||||||
import six
|
import six
|
||||||
|
|
||||||
from oslo_privsep._i18n import _
|
from oslo_privsep._i18n import _
|
||||||
|
from oslo_utils import uuidutils
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
try:
|
@enum.unique
|
||||||
import greenlet
|
class Message(enum.IntEnum):
|
||||||
|
"""Types of messages sent across the communication channel"""
|
||||||
|
PING = 1
|
||||||
|
PONG = 2
|
||||||
|
CALL = 3
|
||||||
|
RET = 4
|
||||||
|
ERR = 5
|
||||||
|
LOG = 6
|
||||||
|
|
||||||
def _get_thread_ident():
|
|
||||||
# This returns something sensible, even if the current thread
|
|
||||||
# isn't a greenthread
|
|
||||||
return id(greenlet.getcurrent())
|
|
||||||
|
|
||||||
except ImportError:
|
class PrivsepTimeout(Exception):
|
||||||
def _get_thread_ident():
|
pass
|
||||||
return threading.current_thread().ident
|
|
||||||
|
|
||||||
|
|
||||||
class Serializer(object):
|
class Serializer(object):
|
||||||
@ -89,10 +93,11 @@ class Deserializer(six.Iterator):
|
|||||||
class Future(object):
|
class Future(object):
|
||||||
"""A very simple object to track the return of a function call"""
|
"""A very simple object to track the return of a function call"""
|
||||||
|
|
||||||
def __init__(self, lock):
|
def __init__(self, lock, timeout=None):
|
||||||
self.condvar = threading.Condition(lock)
|
self.condvar = threading.Condition(lock)
|
||||||
self.error = None
|
self.error = None
|
||||||
self.data = None
|
self.data = None
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
def set_result(self, data):
|
def set_result(self, data):
|
||||||
"""Must already be holding lock used in constructor"""
|
"""Must already be holding lock used in constructor"""
|
||||||
@ -106,7 +111,16 @@ class Future(object):
|
|||||||
|
|
||||||
def result(self):
|
def result(self):
|
||||||
"""Must already be holding lock used in constructor"""
|
"""Must already be holding lock used in constructor"""
|
||||||
self.condvar.wait()
|
before = datetime.datetime.now()
|
||||||
|
if not self.condvar.wait(timeout=self.timeout):
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
LOG.warning('Timeout while executing a command, timeout: %s, '
|
||||||
|
'time elapsed: %s', self.timeout,
|
||||||
|
(now - before).total_seconds())
|
||||||
|
return (Message.ERR.value,
|
||||||
|
'%s.%s' % (PrivsepTimeout.__module__,
|
||||||
|
PrivsepTimeout.__name__),
|
||||||
|
'')
|
||||||
if self.error is not None:
|
if self.error is not None:
|
||||||
raise self.error
|
raise self.error
|
||||||
return self.data
|
return self.data
|
||||||
@ -138,8 +152,9 @@ class ClientChannel(object):
|
|||||||
else:
|
else:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if msgid not in self.outstanding_msgs:
|
if msgid not in self.outstanding_msgs:
|
||||||
raise AssertionError("msgid should in "
|
LOG.warning("msgid should be in oustanding_msgs, it is"
|
||||||
"outstanding_msgs.")
|
"possible that timeout is reached!")
|
||||||
|
continue
|
||||||
self.outstanding_msgs[msgid].set_result(data)
|
self.outstanding_msgs[msgid].set_result(data)
|
||||||
|
|
||||||
# EOF. Perhaps the privileged process exited?
|
# EOF. Perhaps the privileged process exited?
|
||||||
@ -158,13 +173,14 @@ class ClientChannel(object):
|
|||||||
"""Received OOB message. Subclasses might want to override this."""
|
"""Received OOB message. Subclasses might want to override this."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def send_recv(self, msg):
|
def send_recv(self, msg, timeout=None):
|
||||||
myid = _get_thread_ident()
|
myid = uuidutils.generate_uuid()
|
||||||
future = Future(self.lock)
|
while myid in self.outstanding_msgs:
|
||||||
|
LOG.warning("myid shoudn't be in outstanding_msgs.")
|
||||||
|
myid = uuidutils.generate_uuid()
|
||||||
|
future = Future(self.lock, timeout)
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if myid in self.outstanding_msgs:
|
|
||||||
raise AssertionError("myid shoudn't be in outstanding_msgs.")
|
|
||||||
self.outstanding_msgs[myid] = future
|
self.outstanding_msgs[myid] = future
|
||||||
try:
|
try:
|
||||||
self.writer.send((myid, msg))
|
self.writer.send((myid, msg))
|
||||||
|
@ -109,17 +109,6 @@ class StdioFd(enum.IntEnum):
|
|||||||
STDERR = 2
|
STDERR = 2
|
||||||
|
|
||||||
|
|
||||||
@enum.unique
|
|
||||||
class Message(enum.IntEnum):
|
|
||||||
"""Types of messages sent across the communication channel"""
|
|
||||||
PING = 1
|
|
||||||
PONG = 2
|
|
||||||
CALL = 3
|
|
||||||
RET = 4
|
|
||||||
ERR = 5
|
|
||||||
LOG = 6
|
|
||||||
|
|
||||||
|
|
||||||
class FailedToDropPrivileges(Exception):
|
class FailedToDropPrivileges(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -187,7 +176,7 @@ class PrivsepLogHandler(pylogging.Handler):
|
|||||||
data['msg'] = record.getMessage()
|
data['msg'] = record.getMessage()
|
||||||
data['args'] = ()
|
data['args'] = ()
|
||||||
|
|
||||||
self.channel.send((None, (Message.LOG, data)))
|
self.channel.send((None, (comm.Message.LOG, data)))
|
||||||
|
|
||||||
|
|
||||||
class _ClientChannel(comm.ClientChannel):
|
class _ClientChannel(comm.ClientChannel):
|
||||||
@ -201,8 +190,8 @@ class _ClientChannel(comm.ClientChannel):
|
|||||||
def exchange_ping(self):
|
def exchange_ping(self):
|
||||||
try:
|
try:
|
||||||
# exchange "ready" messages
|
# exchange "ready" messages
|
||||||
reply = self.send_recv((Message.PING.value,))
|
reply = self.send_recv((comm.Message.PING.value,))
|
||||||
success = reply[0] == Message.PONG
|
success = reply[0] == comm.Message.PONG
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.exception('Error while sending initial PING to privsep: '
|
self.log.exception('Error while sending initial PING to privsep: '
|
||||||
'%s', e)
|
'%s', e)
|
||||||
@ -212,12 +201,13 @@ class _ClientChannel(comm.ClientChannel):
|
|||||||
self.log.critical(msg)
|
self.log.critical(msg)
|
||||||
raise FailedToDropPrivileges(msg)
|
raise FailedToDropPrivileges(msg)
|
||||||
|
|
||||||
def remote_call(self, name, args, kwargs):
|
def remote_call(self, name, args, kwargs, timeout):
|
||||||
result = self.send_recv((Message.CALL.value, name, args, kwargs))
|
result = self.send_recv((comm.Message.CALL.value, name, args, kwargs),
|
||||||
if result[0] == Message.RET:
|
timeout)
|
||||||
|
if result[0] == comm.Message.RET:
|
||||||
# (RET, return value)
|
# (RET, return value)
|
||||||
return result[1]
|
return result[1]
|
||||||
elif result[0] == Message.ERR:
|
elif result[0] == comm.Message.ERR:
|
||||||
# (ERR, exc_type, args)
|
# (ERR, exc_type, args)
|
||||||
#
|
#
|
||||||
# TODO(gus): see what can be done to preserve traceback
|
# TODO(gus): see what can be done to preserve traceback
|
||||||
@ -228,7 +218,7 @@ class _ClientChannel(comm.ClientChannel):
|
|||||||
raise ProtocolError(_('Unexpected response: %r') % result)
|
raise ProtocolError(_('Unexpected response: %r') % result)
|
||||||
|
|
||||||
def out_of_band(self, msg):
|
def out_of_band(self, msg):
|
||||||
if msg[0] == Message.LOG:
|
if msg[0] == comm.Message.LOG:
|
||||||
# (LOG, LogRecord __dict__)
|
# (LOG, LogRecord __dict__)
|
||||||
message = {encodeutils.safe_decode(k): v
|
message = {encodeutils.safe_decode(k): v
|
||||||
for k, v in msg[1].items()}
|
for k, v in msg[1].items()}
|
||||||
@ -470,11 +460,11 @@ class Daemon(object):
|
|||||||
:return: A tuple of the return status, optional call output, and
|
:return: A tuple of the return status, optional call output, and
|
||||||
optional error information.
|
optional error information.
|
||||||
"""
|
"""
|
||||||
if cmd == Message.PING:
|
if cmd == comm.Message.PING:
|
||||||
return (Message.PONG.value,)
|
return (comm.Message.PONG.value,)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cmd != Message.CALL:
|
if cmd != comm.Message.CALL:
|
||||||
raise ProtocolError(_('Unknown privsep cmd: %s') % cmd)
|
raise ProtocolError(_('Unknown privsep cmd: %s') % cmd)
|
||||||
|
|
||||||
# Extract the callable and arguments
|
# Extract the callable and arguments
|
||||||
@ -485,14 +475,14 @@ class Daemon(object):
|
|||||||
raise NameError(msg)
|
raise NameError(msg)
|
||||||
|
|
||||||
ret = func(*f_args, **f_kwargs)
|
ret = func(*f_args, **f_kwargs)
|
||||||
return (Message.RET.value, ret)
|
return (comm.Message.RET.value, ret)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
'privsep: Exception during request[%(msgid)s]: '
|
'privsep: Exception during request[%(msgid)s]: '
|
||||||
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
|
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
|
||||||
cls = e.__class__
|
cls = e.__class__
|
||||||
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
|
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
|
||||||
return (Message.ERR.value, cls_name, e.args)
|
return (comm.Message.ERR.value, cls_name, e.args)
|
||||||
|
|
||||||
def _create_done_callback(self, msgid):
|
def _create_done_callback(self, msgid):
|
||||||
"""Creates a future callback to receive command execution results.
|
"""Creates a future callback to receive command execution results.
|
||||||
@ -520,7 +510,7 @@ class Daemon(object):
|
|||||||
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
|
'%(err)s', {'msgid': msgid, 'err': e}, exc_info=True)
|
||||||
cls = e.__class__
|
cls = e.__class__
|
||||||
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
|
cls_name = '%s.%s' % (cls.__module__, cls.__name__)
|
||||||
reply = (Message.ERR.value, cls_name, e.args)
|
reply = (comm.Message.ERR.value, cls_name, e.args)
|
||||||
try:
|
try:
|
||||||
channel.send((msgid, reply))
|
channel.send((msgid, reply))
|
||||||
except IOError:
|
except IOError:
|
||||||
|
@ -20,6 +20,7 @@ import unittest
|
|||||||
from oslo_config import fixture as config_fixture
|
from oslo_config import fixture as config_fixture
|
||||||
from oslotest import base
|
from oslotest import base
|
||||||
|
|
||||||
|
from oslo_privsep import comm
|
||||||
from oslo_privsep import priv_context
|
from oslo_privsep import priv_context
|
||||||
|
|
||||||
|
|
||||||
@ -30,6 +31,14 @@ test_context = priv_context.PrivContext(
|
|||||||
capabilities=[],
|
capabilities=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
test_context_with_timeout = priv_context.PrivContext(
|
||||||
|
__name__,
|
||||||
|
cfg_section='privsep',
|
||||||
|
pypath=__name__ + '.test_context_with_timeout',
|
||||||
|
capabilities=[],
|
||||||
|
timeout=0.03
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@test_context.entrypoint
|
@test_context.entrypoint
|
||||||
def sleep():
|
def sleep():
|
||||||
@ -37,6 +46,18 @@ def sleep():
|
|||||||
time.sleep(.001)
|
time.sleep(.001)
|
||||||
|
|
||||||
|
|
||||||
|
@test_context.entrypoint_with_timeout(0.03)
|
||||||
|
def sleep_with_timeout(long_timeout=0.04):
|
||||||
|
time.sleep(long_timeout)
|
||||||
|
return 42
|
||||||
|
|
||||||
|
|
||||||
|
@test_context_with_timeout.entrypoint
|
||||||
|
def sleep_with_t_context(long_timeout=0.04):
|
||||||
|
time.sleep(long_timeout)
|
||||||
|
return 42
|
||||||
|
|
||||||
|
|
||||||
@test_context.entrypoint
|
@test_context.entrypoint
|
||||||
def one():
|
def one():
|
||||||
return 1
|
return 1
|
||||||
@ -65,6 +86,28 @@ class TestDaemon(base.BaseTestCase):
|
|||||||
# Make sure the daemon is still working
|
# Make sure the daemon is still working
|
||||||
self.assertEqual(1, one())
|
self.assertEqual(1, one())
|
||||||
|
|
||||||
|
def test_entrypoint_with_timeout(self):
|
||||||
|
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||||
|
for _ in range(thread_pool_size + 1):
|
||||||
|
self.assertRaises(comm.PrivsepTimeout, sleep_with_timeout)
|
||||||
|
|
||||||
|
def test_entrypoint_with_timeout_pass(self):
|
||||||
|
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||||
|
for _ in range(thread_pool_size + 1):
|
||||||
|
res = sleep_with_timeout(0.01)
|
||||||
|
self.assertEqual(42, res)
|
||||||
|
|
||||||
|
def test_context_with_timeout(self):
|
||||||
|
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||||
|
for _ in range(thread_pool_size + 1):
|
||||||
|
self.assertRaises(comm.PrivsepTimeout, sleep_with_t_context)
|
||||||
|
|
||||||
|
def test_context_with_timeout_pass(self):
|
||||||
|
thread_pool_size = self.cfg_fixture.conf.privsep.thread_pool_size
|
||||||
|
for _ in range(thread_pool_size + 1):
|
||||||
|
res = sleep_with_t_context(0.01)
|
||||||
|
self.assertEqual(42, res)
|
||||||
|
|
||||||
def test_logging(self):
|
def test_logging(self):
|
||||||
logs()
|
logs()
|
||||||
self.assertIn('foo', self.log_fixture.logger.output)
|
self.assertIn('foo', self.log_fixture.logger.output)
|
||||||
|
@ -128,7 +128,8 @@ def init(root_helper=None):
|
|||||||
|
|
||||||
class PrivContext(object):
|
class PrivContext(object):
|
||||||
def __init__(self, prefix, cfg_section='privsep', pypath=None,
|
def __init__(self, prefix, cfg_section='privsep', pypath=None,
|
||||||
capabilities=None, logger_name='oslo_privsep.daemon'):
|
capabilities=None, logger_name='oslo_privsep.daemon',
|
||||||
|
timeout=None):
|
||||||
|
|
||||||
# Note that capabilities=[] means retaining no capabilities
|
# Note that capabilities=[] means retaining no capabilities
|
||||||
# and leaves even uid=0 with no powers except being able to
|
# and leaves even uid=0 with no powers except being able to
|
||||||
@ -156,6 +157,7 @@ class PrivContext(object):
|
|||||||
default=capabilities)
|
default=capabilities)
|
||||||
cfg.CONF.set_default('logger_name', group=cfg_section,
|
cfg.CONF.set_default('logger_name', group=cfg_section,
|
||||||
default=logger_name)
|
default=logger_name)
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def conf(self):
|
def conf(self):
|
||||||
@ -221,7 +223,22 @@ class PrivContext(object):
|
|||||||
|
|
||||||
def entrypoint(self, func):
|
def entrypoint(self, func):
|
||||||
"""This is intended to be used as a decorator."""
|
"""This is intended to be used as a decorator."""
|
||||||
|
return self._entrypoint(func)
|
||||||
|
|
||||||
|
def entrypoint_with_timeout(self, timeout):
|
||||||
|
"""This is intended to be used as a decorator with timeout."""
|
||||||
|
|
||||||
|
def wrap(func):
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
f = self._entrypoint(func)
|
||||||
|
return f(*args, _wrap_timeout=timeout, **kwargs)
|
||||||
|
setattr(inner, _ENTRYPOINT_ATTR, self)
|
||||||
|
return inner
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
def _entrypoint(self, func):
|
||||||
if not func.__module__.startswith(self.prefix):
|
if not func.__module__.startswith(self.prefix):
|
||||||
raise AssertionError('%r entrypoints must be below "%s"' %
|
raise AssertionError('%r entrypoints must be below "%s"' %
|
||||||
(self, self.prefix))
|
(self, self.prefix))
|
||||||
@ -242,7 +259,7 @@ class PrivContext(object):
|
|||||||
def is_entrypoint(self, func):
|
def is_entrypoint(self, func):
|
||||||
return getattr(func, _ENTRYPOINT_ATTR, None) is self
|
return getattr(func, _ENTRYPOINT_ATTR, None) is self
|
||||||
|
|
||||||
def _wrap(self, func, *args, **kwargs):
|
def _wrap(self, func, *args, _wrap_timeout=None, **kwargs):
|
||||||
if self.client_mode:
|
if self.client_mode:
|
||||||
name = '%s.%s' % (func.__module__, func.__name__)
|
name = '%s.%s' % (func.__module__, func.__name__)
|
||||||
if self.channel is not None and not self.channel.running:
|
if self.channel is not None and not self.channel.running:
|
||||||
@ -250,7 +267,9 @@ class PrivContext(object):
|
|||||||
self.stop()
|
self.stop()
|
||||||
if self.channel is None:
|
if self.channel is None:
|
||||||
self.start()
|
self.start()
|
||||||
return self.channel.remote_call(name, args, kwargs)
|
r_call_timeout = _wrap_timeout or self.timeout
|
||||||
|
return self.channel.remote_call(name, args, kwargs,
|
||||||
|
r_call_timeout)
|
||||||
else:
|
else:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -216,7 +216,7 @@ class ClientChannelTestCase(base.BaseTestCase):
|
|||||||
|
|
||||||
@mock.patch.object(daemon.LOG.logger, 'handle')
|
@mock.patch.object(daemon.LOG.logger, 'handle')
|
||||||
def test_out_of_band_log_message(self, handle_mock):
|
def test_out_of_band_log_message(self, handle_mock):
|
||||||
message = [daemon.Message.LOG, self.DICT]
|
message = [comm.Message.LOG, self.DICT]
|
||||||
self.assertEqual(self.client_channel.log, daemon.LOG)
|
self.assertEqual(self.client_channel.log, daemon.LOG)
|
||||||
with mock.patch.object(pylogging, 'makeLogRecord') as mock_make_log, \
|
with mock.patch.object(pylogging, 'makeLogRecord') as mock_make_log, \
|
||||||
mock.patch.object(daemon.LOG, 'isEnabledFor',
|
mock.patch.object(daemon.LOG, 'isEnabledFor',
|
||||||
@ -229,7 +229,7 @@ class ClientChannelTestCase(base.BaseTestCase):
|
|||||||
|
|
||||||
def test_out_of_band_not_log_message(self):
|
def test_out_of_band_not_log_message(self):
|
||||||
with mock.patch.object(daemon.LOG, 'warning') as mock_warning:
|
with mock.patch.object(daemon.LOG, 'warning') as mock_warning:
|
||||||
self.client_channel.out_of_band([daemon.Message.PING])
|
self.client_channel.out_of_band([comm.Message.PING])
|
||||||
mock_warning.assert_called_once()
|
mock_warning.assert_called_once()
|
||||||
|
|
||||||
@mock.patch.object(daemon.logging, 'getLogger')
|
@mock.patch.object(daemon.logging, 'getLogger')
|
||||||
@ -245,7 +245,7 @@ class ClientChannelTestCase(base.BaseTestCase):
|
|||||||
get_logger_mock.assert_called_once_with(logger_name)
|
get_logger_mock.assert_called_once_with(logger_name)
|
||||||
self.assertEqual(get_logger_mock.return_value, channel.log)
|
self.assertEqual(get_logger_mock.return_value, channel.log)
|
||||||
|
|
||||||
message = [daemon.Message.LOG, self.DICT]
|
message = [comm.Message.LOG, self.DICT]
|
||||||
channel.out_of_band(message)
|
channel.out_of_band(message)
|
||||||
|
|
||||||
make_log_mock.assert_called_once_with(self.EXPECTED)
|
make_log_mock.assert_called_once_with(self.EXPECTED)
|
||||||
|
@ -19,10 +19,12 @@ import pipes
|
|||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import testtools
|
import testtools
|
||||||
|
|
||||||
|
from oslo_privsep import comm
|
||||||
from oslo_privsep import daemon
|
from oslo_privsep import daemon
|
||||||
from oslo_privsep import priv_context
|
from oslo_privsep import priv_context
|
||||||
from oslo_privsep.tests import testctx
|
from oslo_privsep.tests import testctx
|
||||||
@ -40,6 +42,12 @@ def add1(arg):
|
|||||||
return arg + 1
|
return arg + 1
|
||||||
|
|
||||||
|
|
||||||
|
@testctx.context.entrypoint_with_timeout(0.2)
|
||||||
|
def do_some_long(long_timeout=0.4):
|
||||||
|
time.sleep(long_timeout)
|
||||||
|
return 42
|
||||||
|
|
||||||
|
|
||||||
class CustomError(Exception):
|
class CustomError(Exception):
|
||||||
def __init__(self, code, msg):
|
def __init__(self, code, msg):
|
||||||
super(CustomError, self).__init__(code, msg)
|
super(CustomError, self).__init__(code, msg)
|
||||||
@ -188,6 +196,16 @@ class RootwrapTest(testctx.TestContextTestCase):
|
|||||||
priv_pid = priv_getpid()
|
priv_pid = priv_getpid()
|
||||||
self.assertNotMyPid(priv_pid)
|
self.assertNotMyPid(priv_pid)
|
||||||
|
|
||||||
|
def test_long_call_with_timeout(self):
|
||||||
|
self.assertRaises(
|
||||||
|
comm.PrivsepTimeout,
|
||||||
|
do_some_long
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_long_call_within_timeout(self):
|
||||||
|
res = do_some_long(0.001)
|
||||||
|
self.assertEqual(42, res)
|
||||||
|
|
||||||
|
|
||||||
@testtools.skipIf(platform.system() != 'Linux',
|
@testtools.skipIf(platform.system() != 'Linux',
|
||||||
'works only on Linux platform.')
|
'works only on Linux platform.')
|
||||||
|
@ -0,0 +1,11 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Add ``timeout`` as parameter to ``PrivContext`` and add
|
||||||
|
``entrypoint_with_timeout`` decorator to cover the issues with
|
||||||
|
commands which take random time to finish.
|
||||||
|
``PrivsepTimeout`` is raised if timeout is reached.
|
||||||
|
|
||||||
|
``Warning``: The daemon (the root process) task won't stop when timeout
|
||||||
|
is reached. That means we'll have less available threads if the related
|
||||||
|
thread never finishes.
|
Loading…
Reference in New Issue
Block a user