oslo.privsep/oslo_privsep/tests/test_priv_context.py
Angus Lees 9bf606327d Provide a way to "initialise" oslo.privsep
Specifically, the goal here is to provide a default that can use
rootwrap.

This change implements a `priv_context.init` function that allows
oslo.privsep to hook into the startup of programs using oslo.privsep.
The intention is to call this function near the top of main() - after
oslo.config is available but before anything "interesting" is performed.

In this change, this init function just allows you to set the default
"run as root" prefix for helper_command to include something like
rootwrap.

In the future, it is expected to use this same call point to do other
"early" tasks like immediately forking privileged helpers and dropping
root if already running as root.

Change-Id: I3ea73e16b07a870629e7d69e897f2524d7068ae8
Partial-Bug: #1592043
2016-06-16 15:17:00 +10:00

192 lines
6.2 KiB
Python

# Copyright 2015 Rackspace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
import os
import pipes
import platform
import sys
import mock
import testtools
from oslo_privsep import daemon
from oslo_privsep import priv_context
from oslo_privsep.tests import testctx
# Module-level logger; RootwrapTest checks whether it is enabled for DEBUG
# to decide whether to pass --debug to the privsep helper.
LOG = logging.getLogger(name=__name__)
@testctx.context.entrypoint
def priv_getpid():
    """Report the PID of whichever process executes this entrypoint.

    Tests compare the result with their own PID to tell whether the call
    ran locally or in another process.
    """
    current_pid = os.getpid()
    return current_pid
@testctx.context.entrypoint
def add1(arg):
    """Return *arg* plus one; used to check argument/result round trips."""
    incremented = arg + 1
    return incremented
class CustomError(Exception):
    """Exception carrying a numeric ``code`` and a text ``msg``.

    Raised by ``fail(custom=True)`` so tests can check that user-defined
    exception types, and their attributes, reach the caller.
    """

    def __init__(self, code, msg):
        super(CustomError, self).__init__(code, msg)
        self.code, self.msg = code, msg

    def __str__(self):
        fields = (self.code, self.msg)
        return 'Code %s: %s' % fields
@testctx.context.entrypoint
def fail(custom=False):
    """Always raise: CustomError when *custom* is true, else RuntimeError."""
    if not custom:
        raise RuntimeError("I can't let you do that Dave")
    raise CustomError(42, 'omg!')
@testtools.skipIf(platform.system() != 'Linux',
                  'works only on Linux platform.')
class TestPrivContext(testctx.TestContextTestCase):
    """Tests for PrivContext client mode and helper command construction."""

    @mock.patch.object(priv_context, 'sys')
    def test_init_windows(self, mock_sys):
        # A context created while sys.platform is win32 starts with
        # client_mode disabled.
        mock_sys.platform = 'win32'
        context = priv_context.PrivContext('test', capabilities=[])
        self.assertFalse(context.client_mode)

    @mock.patch.object(priv_context, 'sys')
    def test_set_client_mode(self, mock_sys):
        context = priv_context.PrivContext('test', capabilities=[])
        self.assertTrue(context.client_mode)

        context.set_client_mode(False)
        self.assertFalse(context.client_mode)

        # Enabling client_mode on win32 must be rejected, and the failed
        # attempt must leave client_mode False.
        mock_sys.platform = 'win32'
        self.assertRaises(RuntimeError, context.set_client_mode, True)
        self.assertFalse(context.client_mode)

    def test_helper_command(self):
        # An explicitly configured helper_command is split and used as-is.
        self.privsep_conf.privsep.helper_command = 'foo --bar'
        cmd = testctx.context.helper_command('/tmp/sockpath')
        expected = [
            'foo', '--bar',
            '--privsep_context', testctx.context.pypath,
            '--privsep_sock_path', '/tmp/sockpath',
        ]
        self.assertEqual(expected, cmd)

    def test_helper_command_default(self):
        self.privsep_conf.config_file = ['/bar.conf']
        cmd = testctx.context.helper_command('/tmp/sockpath')
        expected = [
            'sudo', 'privsep-helper',
            '--config-file', '/bar.conf',
            # --config-dir arg should be skipped
            '--privsep_context', testctx.context.pypath,
            '--privsep_sock_path', '/tmp/sockpath',
        ]
        self.assertEqual(expected, cmd)

    def test_helper_command_default_dirtoo(self):
        self.privsep_conf.config_file = ['/bar.conf', '/baz.conf']
        self.privsep_conf.config_dir = '/foo.d'
        cmd = testctx.context.helper_command('/tmp/sockpath')
        expected = [
            'sudo', 'privsep-helper',
            '--config-file', '/bar.conf',
            '--config-file', '/baz.conf',
            '--config-dir', '/foo.d',
            '--privsep_context', testctx.context.pypath,
            '--privsep_sock_path', '/tmp/sockpath',
        ]
        self.assertEqual(expected, cmd)

    def test_init_known_contexts(self):
        self.assertEqual(testctx.context.helper_command('/sock')[:2],
                         ['sudo', 'privsep-helper'])
        # priv_context.init() changes the root-helper prefix seen by the
        # shared testctx.context, so restore the ['sudo'] prefix asserted
        # above once this test finishes; otherwise later tests see it.
        self.addCleanup(priv_context.init, root_helper=['sudo'])
        priv_context.init(root_helper=['sudo', 'rootwrap'])
        self.assertEqual(testctx.context.helper_command('/sock')[:3],
                         ['sudo', 'rootwrap', 'privsep-helper'])
@testtools.skipIf(platform.system() != 'Linux',
                  'works only on Linux platform.')
class TestSeparation(testctx.TestContextTestCase):
    """Check where entrypoints execute, with and without client mode."""

    def test_getpid(self):
        # A PID different from ours proves the call ran in another process.
        self.assertNotMyPid(priv_getpid())

    def test_client_mode(self):
        ctx = testctx.context
        self.assertNotMyPid(priv_getpid())
        self.addCleanup(ctx.set_client_mode, True)

        # Client mode off: the entrypoint runs in this very process.
        ctx.set_client_mode(False)
        self.assertEqual(os.getpid(), priv_getpid())

        # Client mode back on: calls are dispatched remotely again.
        ctx.set_client_mode(True)
        self.assertNotMyPid(priv_getpid())
@testtools.skipIf(platform.system() != 'Linux',
                  'works only on Linux platform.')
class RootwrapTest(testctx.TestContextTestCase):
    """Start the privsep helper with the ROOTWRAP method and call into it."""

    def setUp(self):
        super(RootwrapTest, self).setUp()
        # Stop the context so the ROOTWRAP start below launches the helper
        # through the helper_command overridden here.
        testctx.context.stop()

        # Generate a command that will run daemon.helper_main without
        # requiring it to be properly installed. The child interpreter gets
        # our sys.path via PYTHONPATH, which is the variable Python reads
        # (a PYTHON_PATH variable would be silently ignored).
        cmd = [
            'env',
            'PYTHONPATH=%s' % os.pathsep.join(sys.path),
            sys.executable, daemon.__file__,
        ]
        if LOG.isEnabledFor(logging.DEBUG):
            cmd.append('--debug')

        # helper_command is a single shell-style string, so quote each arg.
        self.privsep_conf.set_override(
            'helper_command', ' '.join(map(pipes.quote, cmd)),
            group=testctx.context.cfg_section)
        testctx.context.start(method=priv_context.Method.ROOTWRAP)

    def test_getpid(self):
        # Verify that priv_getpid() was executed in another process.
        priv_pid = priv_getpid()
        self.assertNotMyPid(priv_pid)
@testtools.skipIf(platform.system() != 'Linux',
                  'works only on Linux platform.')
class TestSerialization(testctx.TestContextTestCase):
    """Check that arguments, return values and exceptions survive a call
    through a privileged entrypoint."""

    def test_basic_functionality(self):
        self.assertEqual(43, add1(42))

    def test_raises_standard(self):
        # assertRaisesRegexp is a deprecated alias (removed in Python 3.12).
        # The expected text has no regex metacharacters, so a substring
        # check on the returned exception is equivalent.
        exc = self.assertRaises(RuntimeError, fail)
        self.assertIn("I can't let you do that Dave", str(exc))

    def test_raises_custom(self):
        exc = self.assertRaises(CustomError, fail, custom=True)
        self.assertEqual(42, exc.code)
        self.assertEqual('omg!', exc.msg)