diff --git a/tests/unit/test_shell.py b/tests/unit/test_shell.py index bdba1939..9d442486 100644 --- a/tests/unit/test_shell.py +++ b/tests/unit/test_shell.py @@ -24,10 +24,10 @@ import swiftclient from swiftclient.service import SwiftError import swiftclient.shell import swiftclient.utils -from swiftclient.multithreading import OutputManager from os.path import basename, dirname from tests.unit.test_swiftclient import MockHttpTest +from tests.unit.utils import CaptureOutput if six.PY2: BUILTIN_OPEN = '__builtin__.open' @@ -562,20 +562,18 @@ class TestSubcommandHelp(unittest.TestCase): for command in swiftclient.shell.commands: help_var = 'st_%s_help' % command self.assertTrue(help_var in vars(swiftclient.shell)) - out = six.StringIO() - with mock.patch('sys.stdout', out): + with CaptureOutput() as out: argv = ['', command, '--help'] self.assertRaises(SystemExit, swiftclient.shell.main, argv) expected = vars(swiftclient.shell)[help_var] - self.assertEqual(out.getvalue().strip('\n'), expected) + self.assertEqual(out.strip('\n'), expected) def test_no_help(self): - out = six.StringIO() - with mock.patch('sys.stdout', out): + with CaptureOutput() as out: argv = ['', 'bad_command', '--help'] self.assertRaises(SystemExit, swiftclient.shell.main, argv) expected = 'no help for bad_command' - self.assertEqual(out.getvalue().strip('\n'), expected) + self.assertEqual(out.strip('\n'), expected) class TestParsing(unittest.TestCase): @@ -583,7 +581,7 @@ class TestParsing(unittest.TestCase): def setUp(self): super(TestParsing, self).setUp() self._environ_vars = {} - keys = os.environ.keys() + keys = list(os.environ.keys()) for k in keys: if (k in ('ST_KEY', 'ST_USER', 'ST_AUTH') or k.startswith('OS_')): @@ -790,21 +788,15 @@ class TestParsing(unittest.TestCase): "tenant_name": "", "tenant_id": ""} - out = six.StringIO() - err = six.StringIO() - mock_output = _make_output_manager(out, err) - with mock.patch('swiftclient.shell.OutputManager', mock_output): + with CaptureOutput() as output: args = _make_args("stat", {}, os_opts) self.assertRaises(SystemExit, swiftclient.shell.main, args) - self.assertEqual(err.getvalue().strip(), 'No tenant specified') + self.assertEqual(output.err.strip(), 'No tenant specified') - out = six.StringIO() - err = six.StringIO() - mock_output = _make_output_manager(out, err) - with mock.patch('swiftclient.shell.OutputManager', mock_output): + with CaptureOutput() as output: args = _make_args("stat", {}, os_opts, cmd_args=["testcontainer"]) self.assertRaises(SystemExit, swiftclient.shell.main, args) - self.assertEqual(err.getvalue().strip(), 'No tenant specified') + self.assertEqual(output.err.strip(), 'No tenant specified') def test_no_tenant_name_or_id_v3(self): os_opts = {"password": "secret", @@ -813,23 +805,17 @@ class TestParsing(unittest.TestCase): "tenant_name": "", "tenant_id": ""} - out = six.StringIO() - err = six.StringIO() - mock_output = _make_output_manager(out, err) - with mock.patch('swiftclient.shell.OutputManager', mock_output): + with CaptureOutput() as output: args = _make_args("stat", {"auth_version": "3"}, os_opts) self.assertRaises(SystemExit, swiftclient.shell.main, args) - self.assertEqual(err.getvalue().strip(), + self.assertEqual(output.err.strip(), 'No project name or project id specified.') - out = six.StringIO() - err = six.StringIO() - mock_output = _make_output_manager(out, err) - with mock.patch('swiftclient.shell.OutputManager', mock_output): + with CaptureOutput() as output: args = _make_args("stat", {"auth_version": "3"}, os_opts, cmd_args=["testcontainer"]) self.assertRaises(SystemExit, swiftclient.shell.main, args) - self.assertEqual(err.getvalue().strip(), + self.assertEqual(output.err.strip(), 'No project name or project id specified.') def test_insufficient_env_vars_v3(self): @@ -858,10 +844,8 @@ class TestParsing(unittest.TestCase): opts = {"help": ""} os_opts = {} args = _make_args("stat", opts, os_opts) - mock_stdout = six.StringIO() - with mock.patch('sys.stdout', mock_stdout): + with CaptureOutput() as out: self.assertRaises(SystemExit, swiftclient.shell.main, args) - out = mock_stdout.getvalue() self.assertTrue(out.find('[--key <api_key>]') > 0) self.assertEqual(-1, out.find('--os-username=<auth-user-name>')) @@ -872,20 +856,16 @@ class TestParsing(unittest.TestCase): # "username": "user", # "auth_url": "http://example.com:5000/v3"} args = _make_args("", opts, os_opts) - mock_stdout = six.StringIO() - with mock.patch('sys.stdout', mock_stdout): + with CaptureOutput() as out: self.assertRaises(SystemExit, swiftclient.shell.main, args) - out = mock_stdout.getvalue() self.assertTrue(out.find('[--key <api_key>]') > 0) self.assertEqual(-1, out.find('--os-username=<auth-user-name>')) ## --os-help return os options help opts = {} args = _make_args("", opts, os_opts) - mock_stdout = six.StringIO() - with mock.patch('sys.stdout', mock_stdout): + with CaptureOutput() as out: self.assertRaises(SystemExit, swiftclient.shell.main, args) - out = mock_stdout.getvalue() self.assertTrue(out.find('[--key <api_key>]') > 0) self.assertTrue(out.find('--os-username=<auth-user-name>') > 0) @@ -1150,18 +1130,3 @@ class TestKeystoneOptions(MockHttpTest): opts = {'auth-version': '2.0'} self._test_options(opts, os_opts) - - -def _make_output_manager(stdout, stderr): - class MockOutputManager(OutputManager): - # This class is used to mock OutputManager so that we can - # override stdout and stderr. Mocking sys.stdout & sys.stdout - # doesn't work because they are argument defaults in the - # OutputManager constructor and those defaults are pinned to - # the value of sys.stdout/stderr before we get chance to mock them. - def __init__(self, print_stream=None, error_stream=None): - super(MockOutputManager, self).__init__() - self.print_stream = stdout - self.error_stream = stderr - - return MockOutputManager diff --git a/tests/unit/utils.py b/tests/unit/utils.py index c149abf2..3cbb1606 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -12,11 +12,16 @@ # implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import sys from requests import RequestException from time import sleep import testtools +import mock +import six from six.moves import reload_module from swiftclient import client as c +from swiftclient import shell as s def fake_get_auth_keystone(os_options, exc=None, **kwargs): @@ -213,3 +218,70 @@ class MockHttpTest(testtools.TestCase): def tearDown(self): super(MockHttpTest, self).tearDown() reload_module(c) + + +class CaptureStream(object): + + def __init__(self, stream): + self.stream = stream + self._capture = six.StringIO() + self.streams = [self.stream, self._capture] + + def write(self, *args, **kwargs): + for stream in self.streams: + stream.write(*args, **kwargs) + + def writelines(self, *args, **kwargs): + for stream in self.streams: + stream.writelines(*args, **kwargs) + + def getvalue(self): + return self._capture.getvalue() + + +class CaptureOutput(object): + + def __init__(self): + self._out = CaptureStream(sys.stdout) + self._err = CaptureStream(sys.stderr) + + WrappedOutputManager = functools.partial(s.OutputManager, + print_stream=self._out, + error_stream=self._err) + self.patchers = [ + mock.patch('swiftclient.shell.OutputManager', + WrappedOutputManager), + mock.patch('sys.stdout', self._out), + mock.patch('sys.stderr', self._err), + ] + + def __enter__(self): + for patcher in self.patchers: + patcher.start() + return self + + def __exit__(self, *args, **kwargs): + for patcher in self.patchers: + patcher.stop() + + @property + def out(self): + return self._out.getvalue() + + @property + def err(self): + return self._err.getvalue() + + # act like the string captured by stdout + + def __str__(self): + return self.out + + def __len__(self): + return len(self.out) + + def __eq__(self, other): + return self.out == other + + def __getattr__(self, name): + return getattr(self.out, name)