From 16204c706d15dfc2b5949d8c467427ad6d576299 Mon Sep 17 00:00:00 2001
From: Samuel Merritt <sam@swiftstack.com>
Date: Tue, 17 Dec 2013 16:11:26 -0800
Subject: [PATCH] Preserve tracebacks from run_in_thread

Now the traceback goes all the way down to where the exception came
from, not just down to run_in_thread. Better for debugging.

Change-Id: Iac6acb843a6ecf51ea2672a563d80fa43d731f23
---
 swift/common/utils.py          |  6 ++--
 test/unit/common/test_utils.py | 53 ++++++++++++++++++++++------------
 2 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/swift/common/utils.py b/swift/common/utils.py
index 78f69781a6..f19fd7a48d 100644
--- a/swift/common/utils.py
+++ b/swift/common/utils.py
@@ -2233,8 +2233,8 @@ class ThreadPool(object):
             try:
                 result = func(*args, **kwargs)
                 result_queue.put((ev, True, result))
-            except BaseException as err:
-                result_queue.put((ev, False, err))
+            except BaseException:
+                result_queue.put((ev, False, sys.exc_info()))
             finally:
                 work_queue.task_done()
                 os.write(self.wpipe, self.BYTE)
@@ -2264,7 +2264,7 @@ class ThreadPool(object):
                     if success:
                         ev.send(result)
                     else:
-                        ev.send_exception(result)
+                        ev.send_exception(*result)
                 finally:
                     queue.task_done()
 
diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py
index e7eb0b99a4..4a362b2295 100644
--- a/test/unit/common/test_utils.py
+++ b/test/unit/common/test_utils.py
@@ -33,6 +33,7 @@ from textwrap import dedent
 import tempfile
 import threading
 import time
+import traceback
 import unittest
 import fcntl
 import shutil
@@ -2578,13 +2579,8 @@ class TestThreadpool(unittest.TestCase):
         result = tp.force_run_in_thread(self._capture_args, 1, 2, bert='ernie')
         self.assertEquals(result, {'args': (1, 2),
                                    'kwargs': {'bert': 'ernie'}})
-
-        caught = False
-        try:
-            tp.force_run_in_thread(self._raise_valueerror)
-        except ValueError:
-            caught = True
-        self.assertTrue(caught)
+        self.assertRaises(ValueError, tp.force_run_in_thread,
+                          self._raise_valueerror)
 
     def test_run_in_thread_without_threads(self):
         # with zero threads, run_in_thread doesn't actually do so
@@ -2597,13 +2593,8 @@ class TestThreadpool(unittest.TestCase):
         result = tp.run_in_thread(self._capture_args, 1, 2, bert='ernie')
         self.assertEquals(result, {'args': (1, 2),
                                    'kwargs': {'bert': 'ernie'}})
-
-        caught = False
-        try:
-            tp.run_in_thread(self._raise_valueerror)
-        except ValueError:
-            caught = True
-        self.assertTrue(caught)
+        self.assertRaises(ValueError, tp.run_in_thread,
+                          self._raise_valueerror)
 
     def test_force_run_in_thread_without_threads(self):
         # with zero threads, force_run_in_thread uses eventlet.tpool
@@ -2616,12 +2607,36 @@ class TestThreadpool(unittest.TestCase):
         result = tp.force_run_in_thread(self._capture_args, 1, 2, bert='ernie')
         self.assertEquals(result, {'args': (1, 2),
                                    'kwargs': {'bert': 'ernie'}})
-        caught = False
+        self.assertRaises(ValueError, tp.force_run_in_thread,
+                          self._raise_valueerror)
+
+    def test_preserving_stack_trace_from_thread(self):
+        def gamma():
+            return 1 / 0  # ZeroDivisionError
+
+        def beta():
+            return gamma()
+
+        def alpha():
+            return beta()
+
+        tp = utils.ThreadPool(1)
         try:
-            tp.force_run_in_thread(self._raise_valueerror)
-        except ValueError:
-            caught = True
-        self.assertTrue(caught)
+            tp.run_in_thread(alpha)
+        except ZeroDivisionError:
+            # NB: format is (filename, line number, function name, text)
+            tb_func = [elem[2] for elem
+                       in traceback.extract_tb(sys.exc_traceback)]
+        else:
+            self.fail("Expected ZeroDivisionError")
+
+        self.assertEqual(tb_func[-1], "gamma")
+        self.assertEqual(tb_func[-2], "beta")
+        self.assertEqual(tb_func[-3], "alpha")
+        # omit the middle; what's important is that the start and end are
+        # included, not the exact names of helper methods
+        self.assertEqual(tb_func[1], "run_in_thread")
+        self.assertEqual(tb_func[0], "test_preserving_stack_trace_from_thread")
 
 
 class TestAuditLocationGenerator(unittest.TestCase):