diff --git a/os_brick/executor.py b/os_brick/executor.py index b851aecaa..a1afdc260 100644 --- a/os_brick/executor.py +++ b/os_brick/executor.py @@ -18,7 +18,10 @@ and root_helper settings, so this provides that hook. """ +import threading + from oslo_concurrency import processutils as putils +from oslo_context import context as context_utils from oslo_utils import encodeutils from os_brick.privileged import rootwrap as priv_rootwrap @@ -60,3 +63,22 @@ class Executor(object): def set_root_helper(self, helper): self._root_helper = helper + + +class Thread(threading.Thread): + """Thread class that inherits the parent's context. + + This is useful when you are spawning a thread and want LOG entries to + display the right context information, such as the request. + """ + + def __init__(self, *args, **kwargs): + # Store the caller's context as a private variable shared among threads + self.__context__ = context_utils.get_current() + super(Thread, self).__init__(*args, **kwargs) + + def run(self): + # Store the context in the current thread's request store + if self.__context__: + self.__context__.update_store() + super(Thread, self).run() diff --git a/os_brick/initiator/connectors/iscsi.py b/os_brick/initiator/connectors/iscsi.py index df407038b..0af7efee1 100644 --- a/os_brick/initiator/connectors/iscsi.py +++ b/os_brick/initiator/connectors/iscsi.py @@ -17,7 +17,6 @@ import collections import glob import os import re -import threading import time from oslo_concurrency import lockutils @@ -27,6 +26,7 @@ from oslo_utils import excutils from oslo_utils import strutils from os_brick import exception +from os_brick import executor from os_brick.i18n import _ from os_brick import initiator from os_brick.initiator.connectors import base @@ -642,8 +642,8 @@ class ISCSIConnector(base.BaseLinuxConnector, base_iscsi.BaseISCSIConnector): for ip, iqn, lun in ips_iqns_luns: props = connection_properties.copy() props.update(target_portal=ip, target_iqn=iqn, target_lun=lun) - threads.append(threading.Thread(target=self._connect_vol, - args=(retries, props, data))) + threads.append(executor.Thread(target=self._connect_vol, + args=(retries, props, data))) for thread in threads: thread.start() diff --git a/os_brick/tests/test_executor.py b/os_brick/tests/test_executor.py index 7a12795ce..95b439e42 100644 --- a/os_brick/tests/test_executor.py +++ b/os_brick/tests/test_executor.py @@ -13,10 +13,11 @@ # License for the specific language governing permissions and limitations # under the License. -# import time +import threading import mock from oslo_concurrency import processutils as putils +from oslo_context import context as context_utils import six import testtools @@ -87,3 +88,76 @@ class TestExecutor(base.TestCase): stdout, stderr = executor._execute() self.assertEqual(u'Espa\xf1a', stdout) self.assertEqual(u'Z\xfcrich', stderr) + + +class TestThread(base.TestCase): + def _store_context(self, result): + """Stores current thread's context in result list.""" + result.append(context_utils.get_current()) + + def _run_threads(self, threads): + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + def _do_test(self, thread_class, expected, result=None): + if result is None: + result = [] + threads = [thread_class(target=self._store_context, args=[result]) + for i in range(3)] + self._run_threads(threads) + self.assertEqual([expected] * len(threads), result) + + def test_normal_thread(self): + """Test normal threads don't inherit parent's context.""" + context = context_utils.RequestContext() + context.update_store() + self._do_test(threading.Thread, None) + + def test_no_context(self, result=None): + """Test when parent has no context.""" + context_utils._request_store.context = None + self._do_test(brick_executor.Thread, None, result) + + def test_with_context(self, result=None): + """Test that our class actually inherits the context.""" + context = context_utils.RequestContext() + context.update_store() + self._do_test(brick_executor.Thread, context, result) + + def _run_test(self, test_method, test_args, result): + """Run one of the normal tests and store the result. + + Meant to be run in a different thread, thus the need to store the + result, because by the time the join call completes the test's stack + is no longer available and the exception will have been lost. + """ + try: + test_method(test_args) + result.append(True) + except Exception: + result.append(False) + raise + + def test_no_cross_mix(self): + """Test there's no shared global context between threads.""" + result = [] + contexts = [[], [], []] + threads = [threading.Thread(target=self._run_test, + args=[self.test_with_context, + contexts[0], + result]), + threading.Thread(target=self._run_test, + args=[self.test_no_context, + contexts[1], + result]), + threading.Thread(target=self._run_test, + args=[self.test_with_context, + contexts[2], + result])] + self._run_threads(threads) + # Check that all tests run without raising an exception + self.assertEqual([True, True, True], result) + # Check that the context were not shared + self.assertNotEqual(contexts[0], contexts[2])