Monkey patch threading.current_thread() as well

Fixes bug 115

Patching thread.get_ident() but not threading.current_thread() can
result in _DummyThread objects being created. These objects are
never garbage collected, so they leak memory. In a long-running
process (such as a daemon) that uses green threads regularly, this
can add up to a significant memory leak.
This commit is contained in:
Johannes Erdfelt
2012-02-29 19:22:18 +00:00
parent 5fb37d6665
commit 6a402c593a
3 changed files with 174 additions and 21 deletions

View File

@@ -1,9 +1,16 @@
"""Implements the standard threading module, using greenthreads."""
from eventlet import patcher from eventlet import patcher
from eventlet.green import thread from eventlet.green import thread
from eventlet.green import time from eventlet.green import time
from eventlet.support import greenlets as greenlet
__patched__ = ['_start_new_thread', '_allocate_lock', '_get_ident', '_sleep', __patched__ = ['_start_new_thread', '_allocate_lock', '_get_ident', '_sleep',
'local', 'stack_size', 'Lock'] 'local', 'stack_size', 'Lock', 'currentThread',
'current_thread']
__orig_threading = patcher.original('threading')
__threadlocal = __orig_threading.local()
patcher.inject('threading', patcher.inject('threading',
globals(), globals(),
@@ -11,3 +18,79 @@ patcher.inject('threading',
('time', time)) ('time', time))
del patcher del patcher
# Sequence number used to build the default name of the next _GreenThread.
_count = 1


class _GreenThread(object):
    """Wrapper for GreenThread objects to provide Thread-like attributes
    and methods.

    Only naming and joining are provided.  The name is readable and
    writable through the ``name`` property as well as through the
    ``getName``/``get_name`` and ``setName``/``set_name`` accessors, so
    code that renames the current thread does not fail on this wrapper.
    """

    def __init__(self, g):
        """Wrap *g* and assign it the next ``GreenThread-N`` default name."""
        global _count
        self._g = g
        self._name = 'GreenThread-%d' % _count
        _count += 1

    def __repr__(self):
        return '<_GreenThread(%s, %r)>' % (self._name, self._g)

    def getName(self):
        """Return the current name of this wrapper."""
        return self._name
    get_name = getName

    def setName(self, name):
        """Rename this wrapper; *name* is stored as ``str(name)``."""
        self._name = str(name)
    set_name = setName

    # Built with property() rather than @name.setter so that getName and
    # setName stay the single implementation behind every spelling.
    name = property(getName, setName)

    def join(self, timeout=None):
        """Wait for the wrapped object and return the result of its wait().

        *timeout* is accepted only so the signature matches
        ``threading.Thread.join``; it is ignored, because ``wait()`` is
        called without arguments.
        """
        return self._g.wait()
# Lazily populated by _fixup_thread() with whatever 'threading' imports to.
__threading = None


def _fixup_thread(t):
    """Give *t* a ``get_name`` attribute if the Thread class has gained one.

    Some third-party packages (lockfile) add a ``get_name`` attribute to
    the ``threading.Thread`` class when it is missing.  Thread objects
    that come from the original threading package may not pick that
    patch up, so each returned object is patched individually once the
    imported ``threading.Thread`` class is seen to have ``get_name``.
    This is why monkey patching can be bad...

    Returns *t* itself, modified in place when patching was needed.
    """
    global __threading
    if not __threading:
        # Imported on first use rather than at module import time.
        __threading = __import__('threading')

    class_has_alias = hasattr(__threading.Thread, 'get_name')
    if class_has_alias and not hasattr(t, 'get_name'):
        t.get_name = t.getName
    return t
def current_thread():
    """Green-aware replacement for ``threading.current_thread()``.

    The result is one of:

    * the original threading module's current thread (passed through
      ``_fixup_thread``) when ``greenlet.getcurrent()`` is falsy, i.e.
      when not currently in a greenthread;
    * a ``_GreenThread`` wrapper, cached per current greenlet in
      thread-local storage, when the greenlet offers ``link()`` so the
      cache entry can be removed again;
    * otherwise the original threading module's current thread.
    """
    current = greenlet.getcurrent()
    if not current:
        # Not currently in a greenthread, fall back to standard function
        return _fixup_thread(__orig_threading.current_thread())

    try:
        registry = __threadlocal.active
    except AttributeError:
        # First call on this thread-local: start with an empty registry.
        registry = __threadlocal.active = {}

    key = id(current)
    if key in registry:
        return registry[key]

    def cleanup(g):
        # Passed to link(); removes the entry of the object it is given.
        del registry[id(g)]

    try:
        current.link(cleanup)
    except AttributeError:
        # Not a GreenThread type, so there's no way to hook into the
        # green thread exiting.  Nothing is cached in that case and the
        # standard function is used instead.
        return _fixup_thread(__orig_threading.current_thread())

    wrapper = registry[key] = _GreenThread(current)
    return wrapper


currentThread = current_thread

View File

@@ -223,7 +223,6 @@ def monkey_patch(**on):
on.setdefault(modname, default_on) on.setdefault(modname, default_on)
modules_to_patch = [] modules_to_patch = []
patched_thread = False
if on['os'] and not already_patched.get('os'): if on['os'] and not already_patched.get('os'):
modules_to_patch += _green_os_modules() modules_to_patch += _green_os_modules()
already_patched['os'] = True already_patched['os'] = True
@@ -234,7 +233,6 @@ def monkey_patch(**on):
modules_to_patch += _green_socket_modules() modules_to_patch += _green_socket_modules()
already_patched['socket'] = True already_patched['socket'] = True
if on['thread'] and not already_patched.get('thread'): if on['thread'] and not already_patched.get('thread'):
patched_thread = True
modules_to_patch += _green_thread_modules() modules_to_patch += _green_thread_modules()
already_patched['thread'] = True already_patched['thread'] = True
if on['time'] and not already_patched.get('time'): if on['time'] and not already_patched.get('time'):
@@ -266,27 +264,9 @@ def monkey_patch(**on):
patched_attr = getattr(mod, attr_name, None) patched_attr = getattr(mod, attr_name, None)
if patched_attr is not None: if patched_attr is not None:
setattr(orig_mod, attr_name, patched_attr) setattr(orig_mod, attr_name, patched_attr)
# hacks ahead; this is necessary to prevent a KeyError on program exit
if patched_thread:
_patch_main_thread(sys.modules['threading'])
finally: finally:
imp.release_lock() imp.release_lock()
def _patch_main_thread(mod):
    """Rekey the main thread's entry in ``mod._active``.

    This is gnarly patching specific to the threading module.  Threading
    is always initialized before monkeypatching, so its ``_active`` dict
    holds the main thread under the real thread id, while the patched
    module looks entries up by greenlet id.  The main thread's entry is
    therefore moved from the real id to the green id.  Other threads'
    keys are left untouched.
    """
    real_thread = original('thread')
    main_entry = mod._active.pop(real_thread.get_ident(), None)
    if not main_entry:
        # Nothing registered under the real thread id; nothing to move.
        return
    import eventlet.green.thread
    mod._active[eventlet.green.thread.get_ident()] = main_entry
def is_monkey_patched(module): def is_monkey_patched(module):
"""Returns True if the given module is monkeypatched currently, False if """Returns True if the given module is monkeypatched currently, False if
not. *module* can be either the module itself or its name. not. *module* can be either the module itself or its name.

View File

@@ -293,5 +293,95 @@ print "done"
self.assertEqual(output, "done\n", output) self.assertEqual(output, "done\n", output)
# Subprocess-based checks of what threading.current_thread() reports, and of
# how many entries threading._active holds, after eventlet.monkey_patch().
# Each test writes a small script to a temp file, runs it, and inspects the
# lines it printed.
class Threading(ProcessBase):
# current_thread() called from a thread started with the original
# (unpatched) threading.Thread class.
def test_orig_thread(self):
new_mod = """import eventlet
eventlet.monkey_patch()
from eventlet import patcher
import threading
_threading = patcher.original('threading')
def test():
print repr(threading.current_thread())
t = _threading.Thread(target=test)
t.start()
t.join()
print len(threading._active)
print len(_threading._active)
"""
self.write_to_tempfile("newmod", new_mod)
output, lines = self.launch_subprocess('newmod')
# The script prints three lines. NOTE(review): the fourth entry is
# presumably an empty string left by the trailing newline -- confirm
# against launch_subprocess.
self.assertEqual(len(lines), 4, "\n".join(lines))
# The repr must be that of a Thread object.
self.assert_(lines[0].startswith('<Thread'), lines[0])
# Both the patched and the original module must report one active entry.
self.assertEqual(lines[1], "1", lines[1])
self.assertEqual(lines[2], "1", lines[2])
# current_thread() called from a thread started with the patched
# threading.Thread class.
def test_threading(self):
new_mod = """import eventlet
eventlet.monkey_patch()
import threading
def test():
print repr(threading.current_thread())
t = threading.Thread(target=test)
t.start()
t.join()
print len(threading._active)
"""
self.write_to_tempfile("newmod", new_mod)
output, lines = self.launch_subprocess('newmod')
# Two printed lines; the third entry is presumably the trailing empty
# string (see the note in test_orig_thread).
self.assertEqual(len(lines), 3, "\n".join(lines))
# The repr must be that of the _MainThread object.
self.assert_(lines[0].startswith('<_MainThread'), lines[0])
# Exactly one entry must remain in threading._active.
self.assertEqual(lines[1], "1", lines[1])
# current_thread() called from a function run through tpool.execute().
def test_tpool(self):
new_mod = """import eventlet
eventlet.monkey_patch()
from eventlet import tpool
import threading
def test():
print repr(threading.current_thread())
tpool.execute(test)
print len(threading._active)
"""
self.write_to_tempfile("newmod", new_mod)
output, lines = self.launch_subprocess('newmod')
self.assertEqual(len(lines), 3, "\n".join(lines))
# The repr must be that of a Thread object.
self.assert_(lines[0].startswith('<Thread'), lines[0])
# Exactly one entry must remain in threading._active.
self.assertEqual(lines[1], "1", lines[1])
# current_thread() called from a function started with eventlet.spawn_n();
# an Event is used to wait until that function has run.
def test_greenlet(self):
new_mod = """import eventlet
eventlet.monkey_patch()
from eventlet import event
import threading
evt = event.Event()
def test():
print repr(threading.current_thread())
evt.send()
eventlet.spawn_n(test)
evt.wait()
print len(threading._active)
"""
self.write_to_tempfile("newmod", new_mod)
output, lines = self.launch_subprocess('newmod')
self.assertEqual(len(lines), 3, "\n".join(lines))
# The repr must be that of the _MainThread object.
self.assert_(lines[0].startswith('<_MainThread'), lines[0])
# Exactly one entry must remain in threading._active.
self.assertEqual(lines[1], "1", lines[1])
# current_thread() called from a function started with eventlet.spawn().
def test_greenthread(self):
new_mod = """import eventlet
eventlet.monkey_patch()
import threading
def test():
print repr(threading.current_thread())
t = eventlet.spawn(test)
t.wait()
print len(threading._active)
"""
self.write_to_tempfile("newmod", new_mod)
output, lines = self.launch_subprocess('newmod')
self.assertEqual(len(lines), 3, "\n".join(lines))
# The repr must be that of a _GreenThread wrapper.
self.assert_(lines[0].startswith('<_GreenThread'), lines[0])
# Exactly one entry must remain in threading._active.
self.assertEqual(lines[1], "1", lines[1])
if __name__ == '__main__': if __name__ == '__main__':
main() main()