Fixed main thread patching and moved the code into patcher.py. Implemented a sys.modules state saver that simplifies much of the code in patcher.py. Improved comments and docs in there too.

This commit is contained in:
Ryan Williams
2010-06-04 13:33:14 -07:00
parent 88bf9d0240
commit 0793b503eb
2 changed files with 79 additions and 56 deletions

View File

@@ -11,13 +11,3 @@ patcher.inject('threading',
('time', time))
del patcher
def _patch_main_thread(mod):
# this is some gnarly patching for the threading module;
# if threading is imported before we patch (it nearly always is),
# then the main thread will have the wrong key in threading._active,
# so, we try and replace that key with the correct one here
# this works best if there are no other threads besides the main one
curthread = mod._active.pop(mod._get_ident(), None)
if curthread:
mod._active[thread.get_ident()] = curthread

View File

@@ -5,6 +5,33 @@ __all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched']
__exclude = set(('__builtins__', '__file__', '__name__'))
class SysModulesSaver(object):
"""Class that captures some subset of the current state of
sys.modules. Pass in an iterator of module names to the
constructor."""
def __init__(self, module_names=()):
self._saved = {}
self.save(*module_names)
def save(self, *module_names):
"""Saves the named modules to the object."""
for modname in module_names:
self._saved[modname] = sys.modules.get(modname, None)
def restore(self):
"""Restores the modules that the saver knows about into
sys.modules.
"""
for modname, mod in self._saved.iteritems():
if mod is not None:
sys.modules[modname] = mod
else:
try:
del sys.modules[modname]
except KeyError:
pass
def inject(module_name, new_globals, *additional_modules):
"""Base method for "injecting" greened modules into an imported module. It
imports the module specified in *module_name*, arranging things so
@@ -34,16 +61,20 @@ def inject(module_name, new_globals, *additional_modules):
_green_socket_modules() +
_green_thread_modules() +
_green_time_modules())
# after this we are gonna screw with sys.modules, so capture the
# state of all the modules we're going to mess with
saver = SysModulesSaver([name for name, m in additional_modules])
saver.save(module_name)
## Put the specified modules in sys.modules for the duration of the import
saved = {}
# Cover the target modules so that when you import the module it
# sees only the patched versions
for name, mod in additional_modules:
saved[name] = sys.modules.get(name, None)
sys.modules[name] = mod
## Remove the old module from sys.modules and reimport it while
## the specified modules are in place
old_module = sys.modules.pop(module_name, None)
sys.modules.pop(module_name, None)
try:
module = __import__(module_name, {}, {}, module_name.split('.')[:-1])
@@ -56,18 +87,7 @@ def inject(module_name, new_globals, *additional_modules):
## Keep a reference to the new module to prevent it from dying
sys.modules[patched_name] = module
finally:
## Put the original module back
if old_module is not None:
sys.modules[module_name] = old_module
elif module_name in sys.modules:
del sys.modules[module_name]
## Put all the saved modules back
for name, mod in additional_modules:
if saved[name] is not None:
sys.modules[name] = saved[name]
else:
del sys.modules[name]
saver.restore() ## Put the original modules back
return module
@@ -86,8 +106,11 @@ def import_patched(module_name, *additional_modules, **kw_additional_modules):
def patch_function(func, *additional_modules):
"""Huge hack here -- patches the specified modules for the
duration of the function call."""
"""Decorator that returns a version of the function that patches
some modules for the duration of the function call. This is
deeply gross and should only be used for functions that import
network libraries within their function bodies that there is no
way of getting around."""
if not additional_modules:
# supply some defaults
additional_modules = (
@@ -98,50 +121,44 @@ def patch_function(func, *additional_modules):
_green_time_modules())
def patched(*args, **kw):
saved = {}
saver = SysModulesSaver(additional_modules.keys())
for name, mod in additional_modules:
saved[name] = sys.modules.get(name, None)
sys.modules[name] = mod
try:
return func(*args, **kw)
finally:
## Put all the saved modules back
for name, mod in additional_modules:
if saved[name] is not None:
sys.modules[name] = saved[name]
else:
del sys.modules[name]
saver.restore()
return patched
def _original_patch_function(func, *module_names):
"""Kind of the opposite of patch_function; wraps a function such
that sys.modules is populated only with the unpatched versions of
the specified modules. Also a gross hack; tell your kids not to
import inside function bodies!"""
"""Kind of the contrapositive of patch_function: decorates a
function such that when it's called, sys.modules is populated only
with the unpatched versions of the specified modules. Unlike
patch_function, only the names of the modules need be supplied,
and there are no defaults. This is a gross hack; tell your kids not
to import inside function bodies!"""
def patched(*args, **kw):
saved = {}
saver = SysModulesSaver(module_names)
for name in module_names:
saved[name] = sys.modules.get(name, None)
sys.modules[name] = original(name)
try:
return func(*args, **kw)
finally:
for name in module_names:
if saved[name] is not None:
sys.modules[name] = saved[name]
else:
del sys.modules[name]
saver.restore()
return patched
def original(modname):
""" This returns an unpatched version of a module; this is useful for
Eventlet itself (i.e. tpool)."""
original_name = '__original_module_' + modname
if original_name in sys.modules:
return sys.modules.get(original_name)
# re-import the "pure" module and store it in the global _originals
# dict; be sure to restore whatever module had that name already
current_mod = sys.modules.pop(modname, None)
saver = SysModulesSaver((modname,))
sys.modules.pop(modname, None)
try:
real_mod = __import__(modname, {}, {}, modname.split('.')[:-1])
# hacky hack: Queue's constructor imports threading; therefore
@@ -149,12 +166,11 @@ def original(modname):
# original threading
if modname == 'Queue':
real_mod.Queue.__init__ = _original_patch_function(real_mod.Queue.__init__, 'threading')
# save a reference to the unpatched module so it doesn't get lost
sys.modules[original_name] = real_mod
finally:
if current_mod is not None:
sys.modules[modname] = current_mod
else:
del sys.modules[modname]
saver.restore()
return sys.modules[original_name]
already_patched = {}
@@ -183,6 +199,7 @@ def monkey_patch(**on):
on.setdefault(modname, default_on)
modules_to_patch = []
patched_thread = False
if on['os'] and not already_patched.get('os'):
modules_to_patch += _green_os_modules()
already_patched['os'] = True
@@ -193,10 +210,7 @@ def monkey_patch(**on):
modules_to_patch += _green_socket_modules()
already_patched['socket'] = True
if on['thread'] and not already_patched.get('thread'):
# hacks ahead
threading = original('threading')
import eventlet.green.threading as greenthreading
greenthreading._patch_main_thread(threading)
patched_thread = True
modules_to_patch += _green_thread_modules()
already_patched['thread'] = True
if on['time'] and not already_patched.get('time'):
@@ -222,6 +236,25 @@ def monkey_patch(**on):
if patched_attr is not None:
setattr(orig_mod, attr_name, patched_attr)
# hacks ahead; this is necessary to prevent a KeyError on program exit
if patched_thread:
_patch_main_thread(sys.modules['threading'])
def _patch_main_thread(mod):
"""This is some gnarly patching specific to the threading module;
threading will always be initialized prior to monkeypatching, and
its _active dict will have the wrong key (it uses the real thread
id but once it's patched it will use the greenlet ids); so what we
do is rekey the _active dict so that the main thread's entry uses
the greenthread key. Other threads' keys are ignored."""
thread = original('thread')
curthread = mod._active.pop(thread.get_ident(), None)
if curthread:
import eventlet.green.thread
mod._active[eventlet.green.thread.get_ident()] = curthread
def is_monkey_patched(module):
"""Returns True if the given module is monkeypatched currently, False if
not. *module* can be either the module itself or its name.