python-swiftclient/test/unit/test_multithreading.py

241 lines
8.4 KiB
Python

# Copyright (c) 2010-2013 OpenStack, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import threading
import six
from concurrent.futures import as_completed
from six.moves.queue import Queue, Empty
from time import sleep
from swiftclient import multithreading as mt
from .utils import CaptureStream
class ThreadTestCase(unittest.TestCase):
def setUp(self):
super(ThreadTestCase, self).setUp()
self.got_items = Queue()
self.got_args_kwargs = Queue()
self.starting_thread_count = threading.active_count()
def _func(self, conn, item, *args, **kwargs):
self.got_items.put((conn, item))
self.got_args_kwargs.put((args, kwargs))
if item == 'sleep':
sleep(.1)
if item == 'go boom':
raise Exception('I went boom!')
return 'success'
def _create_conn(self):
return "This is a connection"
def _create_conn_fail(self):
raise Exception("This is a failed connection")
def assertQueueContains(self, queue, expected_contents):
got_contents = []
try:
while True:
got_contents.append(queue.get(timeout=0.1))
except Empty:
pass
if isinstance(expected_contents, set):
got_contents = set(got_contents)
self.assertEqual(expected_contents, got_contents)
class TestConnectionThreadPoolExecutor(ThreadTestCase):
def setUp(self):
super(TestConnectionThreadPoolExecutor, self).setUp()
self.input_queue = Queue()
self.stored_results = []
def tearDown(self):
super(TestConnectionThreadPoolExecutor, self).tearDown()
def test_submit_good_connection(self):
ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn, 1)
with ctpe as pool:
# Try submitting a job that should succeed
f = pool.submit(self._func, "succeed")
f.result()
self.assertQueueContains(
self.got_items,
[("This is a connection", "succeed")]
)
# Now a job that fails
try:
f = pool.submit(self._func, "go boom")
f.result()
except Exception as e:
self.assertEqual('I went boom!', str(e))
else:
self.fail('I never went boom!')
# Has the connection been returned to the pool?
f = pool.submit(self._func, "succeed")
f.result()
self.assertQueueContains(
self.got_items,
[
("This is a connection", "go boom"),
("This is a connection", "succeed")
]
)
def test_submit_bad_connection(self):
ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn_fail, 1)
with ctpe as pool:
# Now a connection that fails
try:
f = pool.submit(self._func, "succeed")
f.result()
except Exception as e:
self.assertEqual('This is a failed connection', str(e))
else:
self.fail('The connection did not fail')
# Make sure we don't lock up on failed connections
try:
f = pool.submit(self._func, "go boom")
f.result()
except Exception as e:
self.assertEqual('This is a failed connection', str(e))
else:
self.fail('The connection did not fail')
def test_lazy_connections(self):
ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn, 10)
with ctpe as pool:
# Submit multiple jobs sequentially - should only use 1 conn
f = pool.submit(self._func, "succeed")
f.result()
f = pool.submit(self._func, "succeed")
f.result()
f = pool.submit(self._func, "succeed")
f.result()
expected_connections = [(0, "This is a connection")]
expected_connections.extend([(x, None) for x in range(1, 10)])
self.assertQueueContains(
pool._connections, expected_connections
)
ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn, 10)
with ctpe as pool:
fs = []
f1 = pool.submit(self._func, "sleep")
f2 = pool.submit(self._func, "sleep")
f3 = pool.submit(self._func, "sleep")
fs.extend([f1, f2, f3])
expected_connections = [
(0, "This is a connection"),
(1, "This is a connection"),
(2, "This is a connection")
]
expected_connections.extend([(x, None) for x in range(3, 10)])
for f in as_completed(fs):
f.result()
self.assertQueueContains(
pool._connections, expected_connections
)
class TestOutputManager(unittest.TestCase):
def test_instantiation(self):
output_manager = mt.OutputManager()
self.assertEqual(sys.stdout, output_manager.print_stream)
self.assertEqual(sys.stderr, output_manager.error_stream)
def test_printers(self):
out_stream = CaptureStream(sys.stdout)
err_stream = CaptureStream(sys.stderr)
starting_thread_count = threading.active_count()
with mt.OutputManager(
print_stream=out_stream,
error_stream=err_stream) as thread_manager:
# Sanity-checking these gives power to the previous test which
# looked at the default values of thread_manager.print/error_stream
self.assertEqual(out_stream, thread_manager.print_stream)
self.assertEqual(err_stream, thread_manager.error_stream)
# No printing has happened yet, so no new threads
self.assertEqual(starting_thread_count,
threading.active_count())
thread_manager.print_msg('one-argument')
thread_manager.print_msg('one %s, %d fish', 'fish', 88)
thread_manager.error('I have %d problems, but a %s is not one',
99, u'\u062A\u062A')
thread_manager.print_msg('some\n%s\nover the %r', 'where',
u'\u062A\u062A')
thread_manager.error('one-error-argument')
thread_manager.error('Sometimes\n%.1f%% just\ndoes not\nwork!',
3.14159)
thread_manager.print_raw(
u'some raw bytes: \u062A\u062A'.encode('utf-8'))
thread_manager.print_items([
('key', 'value'),
('object', u'O\u0308bject'),
])
thread_manager.print_raw(b'\xffugly\xffraw')
# Now we have a thread for error printing and a thread for
# normal print messages
self.assertEqual(starting_thread_count + 2,
threading.active_count())
# The threads should have been cleaned up
self.assertEqual(starting_thread_count, threading.active_count())
if six.PY3:
over_the = "over the '\u062a\u062a'\n"
else:
over_the = "over the u'\\u062a\\u062a'\n"
# We write to the CaptureStream so no decoding is performed
self.assertEqual(''.join([
'one-argument\n',
'one fish, 88 fish\n',
'some\n', 'where\n',
over_the,
u'some raw bytes: \u062a\u062a',
' key: value\n',
u' object: O\u0308bject\n'
]).encode('utf8') + b'\xffugly\xffraw', out_stream.getvalue())
self.assertEqual(''.join([
u'I have 99 problems, but a \u062A\u062A is not one\n',
'one-error-argument\n',
'Sometimes\n', '3.1% just\n', 'does not\n', 'work!\n'
]), err_stream.getvalue().decode('utf8'))
self.assertEqual(3, thread_manager.error_count)