Fixed main thread patching and moved the code into patcher.py. Implemented a sys.modules state saver that simplifies much of the code in patcher.py. Improved comments and docs in there too.
This commit is contained in:
@@ -11,13 +11,3 @@ patcher.inject('threading',
|
||||
('time', time))
|
||||
|
||||
del patcher
|
||||
|
||||
def _patch_main_thread(mod):
|
||||
# this is some gnarly patching for the threading module;
|
||||
# if threading is imported before we patch (it nearly always is),
|
||||
# then the main thread will have the wrong key in threading._active,
|
||||
# so, we try and replace that key with the correct one here
|
||||
# this works best if there are no other threads besides the main one
|
||||
curthread = mod._active.pop(mod._get_ident(), None)
|
||||
if curthread:
|
||||
mod._active[thread.get_ident()] = curthread
|
||||
|
@@ -5,6 +5,33 @@ __all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched']
|
||||
|
||||
__exclude = set(('__builtins__', '__file__', '__name__'))
|
||||
|
||||
class SysModulesSaver(object):
|
||||
"""Class that captures some subset of the current state of
|
||||
sys.modules. Pass in an iterator of module names to the
|
||||
constructor."""
|
||||
def __init__(self, module_names=()):
|
||||
self._saved = {}
|
||||
self.save(*module_names)
|
||||
|
||||
def save(self, *module_names):
|
||||
"""Saves the named modules to the object."""
|
||||
for modname in module_names:
|
||||
self._saved[modname] = sys.modules.get(modname, None)
|
||||
|
||||
def restore(self):
|
||||
"""Restores the modules that the saver knows about into
|
||||
sys.modules.
|
||||
"""
|
||||
for modname, mod in self._saved.iteritems():
|
||||
if mod is not None:
|
||||
sys.modules[modname] = mod
|
||||
else:
|
||||
try:
|
||||
del sys.modules[modname]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def inject(module_name, new_globals, *additional_modules):
|
||||
"""Base method for "injecting" greened modules into an imported module. It
|
||||
imports the module specified in *module_name*, arranging things so
|
||||
@@ -34,16 +61,20 @@ def inject(module_name, new_globals, *additional_modules):
|
||||
_green_socket_modules() +
|
||||
_green_thread_modules() +
|
||||
_green_time_modules())
|
||||
|
||||
# after this we are gonna screw with sys.modules, so capture the
|
||||
# state of all the modules we're going to mess with
|
||||
saver = SysModulesSaver([name for name, m in additional_modules])
|
||||
saver.save(module_name)
|
||||
|
||||
## Put the specified modules in sys.modules for the duration of the import
|
||||
saved = {}
|
||||
# Cover the target modules so that when you import the module it
|
||||
# sees only the patched versions
|
||||
for name, mod in additional_modules:
|
||||
saved[name] = sys.modules.get(name, None)
|
||||
sys.modules[name] = mod
|
||||
|
||||
## Remove the old module from sys.modules and reimport it while
|
||||
## the specified modules are in place
|
||||
old_module = sys.modules.pop(module_name, None)
|
||||
sys.modules.pop(module_name, None)
|
||||
try:
|
||||
module = __import__(module_name, {}, {}, module_name.split('.')[:-1])
|
||||
|
||||
@@ -56,18 +87,7 @@ def inject(module_name, new_globals, *additional_modules):
|
||||
## Keep a reference to the new module to prevent it from dying
|
||||
sys.modules[patched_name] = module
|
||||
finally:
|
||||
## Put the original module back
|
||||
if old_module is not None:
|
||||
sys.modules[module_name] = old_module
|
||||
elif module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
## Put all the saved modules back
|
||||
for name, mod in additional_modules:
|
||||
if saved[name] is not None:
|
||||
sys.modules[name] = saved[name]
|
||||
else:
|
||||
del sys.modules[name]
|
||||
saver.restore() ## Put the original modules back
|
||||
|
||||
return module
|
||||
|
||||
@@ -86,8 +106,11 @@ def import_patched(module_name, *additional_modules, **kw_additional_modules):
|
||||
|
||||
|
||||
def patch_function(func, *additional_modules):
|
||||
"""Huge hack here -- patches the specified modules for the
|
||||
duration of the function call."""
|
||||
"""Decorator that returns a version of the function that patches
|
||||
some modules for the duration of the function call. This is
|
||||
deeply gross and should only be used for functions that import
|
||||
network libraries within their function bodies that there is no
|
||||
way of getting around."""
|
||||
if not additional_modules:
|
||||
# supply some defaults
|
||||
additional_modules = (
|
||||
@@ -98,50 +121,44 @@ def patch_function(func, *additional_modules):
|
||||
_green_time_modules())
|
||||
|
||||
def patched(*args, **kw):
|
||||
saved = {}
|
||||
saver = SysModulesSaver(additional_modules.keys())
|
||||
for name, mod in additional_modules:
|
||||
saved[name] = sys.modules.get(name, None)
|
||||
sys.modules[name] = mod
|
||||
try:
|
||||
return func(*args, **kw)
|
||||
finally:
|
||||
## Put all the saved modules back
|
||||
for name, mod in additional_modules:
|
||||
if saved[name] is not None:
|
||||
sys.modules[name] = saved[name]
|
||||
else:
|
||||
del sys.modules[name]
|
||||
saver.restore()
|
||||
return patched
|
||||
|
||||
def _original_patch_function(func, *module_names):
|
||||
"""Kind of the opposite of patch_function; wraps a function such
|
||||
that sys.modules is populated only with the unpatched versions of
|
||||
the specified modules. Also a gross hack; tell your kids not to
|
||||
import inside function bodies!"""
|
||||
"""Kind of the contrapositive of patch_function: decorates a
|
||||
function such that when it's called, sys.modules is populated only
|
||||
with the unpatched versions of the specified modules. Unlike
|
||||
patch_function, only the names of the modules need be supplied,
|
||||
and there are no defaults. This is a gross hack; tell your kids not
|
||||
to import inside function bodies!"""
|
||||
def patched(*args, **kw):
|
||||
saved = {}
|
||||
saver = SysModulesSaver(module_names)
|
||||
for name in module_names:
|
||||
saved[name] = sys.modules.get(name, None)
|
||||
sys.modules[name] = original(name)
|
||||
try:
|
||||
return func(*args, **kw)
|
||||
finally:
|
||||
for name in module_names:
|
||||
if saved[name] is not None:
|
||||
sys.modules[name] = saved[name]
|
||||
else:
|
||||
del sys.modules[name]
|
||||
saver.restore()
|
||||
return patched
|
||||
|
||||
|
||||
def original(modname):
|
||||
""" This returns an unpatched version of a module; this is useful for
|
||||
Eventlet itself (i.e. tpool)."""
|
||||
original_name = '__original_module_' + modname
|
||||
if original_name in sys.modules:
|
||||
return sys.modules.get(original_name)
|
||||
|
||||
# re-import the "pure" module and store it in the global _originals
|
||||
# dict; be sure to restore whatever module had that name already
|
||||
current_mod = sys.modules.pop(modname, None)
|
||||
saver = SysModulesSaver((modname,))
|
||||
sys.modules.pop(modname, None)
|
||||
try:
|
||||
real_mod = __import__(modname, {}, {}, modname.split('.')[:-1])
|
||||
# hacky hack: Queue's constructor imports threading; therefore
|
||||
@@ -149,12 +166,11 @@ def original(modname):
|
||||
# original threading
|
||||
if modname == 'Queue':
|
||||
real_mod.Queue.__init__ = _original_patch_function(real_mod.Queue.__init__, 'threading')
|
||||
# save a reference to the unpatched module so it doesn't get lost
|
||||
sys.modules[original_name] = real_mod
|
||||
finally:
|
||||
if current_mod is not None:
|
||||
sys.modules[modname] = current_mod
|
||||
else:
|
||||
del sys.modules[modname]
|
||||
saver.restore()
|
||||
|
||||
return sys.modules[original_name]
|
||||
|
||||
already_patched = {}
|
||||
@@ -183,6 +199,7 @@ def monkey_patch(**on):
|
||||
on.setdefault(modname, default_on)
|
||||
|
||||
modules_to_patch = []
|
||||
patched_thread = False
|
||||
if on['os'] and not already_patched.get('os'):
|
||||
modules_to_patch += _green_os_modules()
|
||||
already_patched['os'] = True
|
||||
@@ -193,10 +210,7 @@ def monkey_patch(**on):
|
||||
modules_to_patch += _green_socket_modules()
|
||||
already_patched['socket'] = True
|
||||
if on['thread'] and not already_patched.get('thread'):
|
||||
# hacks ahead
|
||||
threading = original('threading')
|
||||
import eventlet.green.threading as greenthreading
|
||||
greenthreading._patch_main_thread(threading)
|
||||
patched_thread = True
|
||||
modules_to_patch += _green_thread_modules()
|
||||
already_patched['thread'] = True
|
||||
if on['time'] and not already_patched.get('time'):
|
||||
@@ -222,6 +236,25 @@ def monkey_patch(**on):
|
||||
if patched_attr is not None:
|
||||
setattr(orig_mod, attr_name, patched_attr)
|
||||
|
||||
# hacks ahead; this is necessary to prevent a KeyError on program exit
|
||||
if patched_thread:
|
||||
_patch_main_thread(sys.modules['threading'])
|
||||
|
||||
|
||||
def _patch_main_thread(mod):
|
||||
"""This is some gnarly patching specific to the threading module;
|
||||
threading will always be initialized prior to monkeypatching, and
|
||||
its _active dict will have the wrong key (it uses the real thread
|
||||
id but once it's patched it will use the greenlet ids); so what we
|
||||
do is rekey the _active dict so that the main thread's entry uses
|
||||
the greenthread key. Other threads' keys are ignored."""
|
||||
thread = original('thread')
|
||||
curthread = mod._active.pop(thread.get_ident(), None)
|
||||
if curthread:
|
||||
import eventlet.green.thread
|
||||
mod._active[eventlet.green.thread.get_ident()] = curthread
|
||||
|
||||
|
||||
def is_monkey_patched(module):
|
||||
"""Returns True if the given module is monkeypatched currently, False if
|
||||
not. *module* can be either the module itself or its name.
|
||||
|
Reference in New Issue
Block a user