Fixed main thread patching and moved the code into patcher.py. Implemented a sys.modules state saver that simplifies much of the code in patcher.py. Improved comments and docs in there too.
This commit is contained in:
@@ -11,13 +11,3 @@ patcher.inject('threading',
|
|||||||
('time', time))
|
('time', time))
|
||||||
|
|
||||||
del patcher
|
del patcher
|
||||||
|
|
||||||
def _patch_main_thread(mod):
|
|
||||||
# this is some gnarly patching for the threading module;
|
|
||||||
# if threading is imported before we patch (it nearly always is),
|
|
||||||
# then the main thread will have the wrong key in threading._active,
|
|
||||||
# so, we try and replace that key with the correct one here
|
|
||||||
# this works best if there are no other threads besides the main one
|
|
||||||
curthread = mod._active.pop(mod._get_ident(), None)
|
|
||||||
if curthread:
|
|
||||||
mod._active[thread.get_ident()] = curthread
|
|
||||||
|
@@ -5,6 +5,33 @@ __all__ = ['inject', 'import_patched', 'monkey_patch', 'is_monkey_patched']
|
|||||||
|
|
||||||
__exclude = set(('__builtins__', '__file__', '__name__'))
|
__exclude = set(('__builtins__', '__file__', '__name__'))
|
||||||
|
|
||||||
|
class SysModulesSaver(object):
|
||||||
|
"""Class that captures some subset of the current state of
|
||||||
|
sys.modules. Pass in an iterator of module names to the
|
||||||
|
constructor."""
|
||||||
|
def __init__(self, module_names=()):
|
||||||
|
self._saved = {}
|
||||||
|
self.save(*module_names)
|
||||||
|
|
||||||
|
def save(self, *module_names):
|
||||||
|
"""Saves the named modules to the object."""
|
||||||
|
for modname in module_names:
|
||||||
|
self._saved[modname] = sys.modules.get(modname, None)
|
||||||
|
|
||||||
|
def restore(self):
|
||||||
|
"""Restores the modules that the saver knows about into
|
||||||
|
sys.modules.
|
||||||
|
"""
|
||||||
|
for modname, mod in self._saved.iteritems():
|
||||||
|
if mod is not None:
|
||||||
|
sys.modules[modname] = mod
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
del sys.modules[modname]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def inject(module_name, new_globals, *additional_modules):
|
def inject(module_name, new_globals, *additional_modules):
|
||||||
"""Base method for "injecting" greened modules into an imported module. It
|
"""Base method for "injecting" greened modules into an imported module. It
|
||||||
imports the module specified in *module_name*, arranging things so
|
imports the module specified in *module_name*, arranging things so
|
||||||
@@ -34,16 +61,20 @@ def inject(module_name, new_globals, *additional_modules):
|
|||||||
_green_socket_modules() +
|
_green_socket_modules() +
|
||||||
_green_thread_modules() +
|
_green_thread_modules() +
|
||||||
_green_time_modules())
|
_green_time_modules())
|
||||||
|
|
||||||
|
# after this we are gonna screw with sys.modules, so capture the
|
||||||
|
# state of all the modules we're going to mess with
|
||||||
|
saver = SysModulesSaver([name for name, m in additional_modules])
|
||||||
|
saver.save(module_name)
|
||||||
|
|
||||||
## Put the specified modules in sys.modules for the duration of the import
|
# Cover the target modules so that when you import the module it
|
||||||
saved = {}
|
# sees only the patched versions
|
||||||
for name, mod in additional_modules:
|
for name, mod in additional_modules:
|
||||||
saved[name] = sys.modules.get(name, None)
|
|
||||||
sys.modules[name] = mod
|
sys.modules[name] = mod
|
||||||
|
|
||||||
## Remove the old module from sys.modules and reimport it while
|
## Remove the old module from sys.modules and reimport it while
|
||||||
## the specified modules are in place
|
## the specified modules are in place
|
||||||
old_module = sys.modules.pop(module_name, None)
|
sys.modules.pop(module_name, None)
|
||||||
try:
|
try:
|
||||||
module = __import__(module_name, {}, {}, module_name.split('.')[:-1])
|
module = __import__(module_name, {}, {}, module_name.split('.')[:-1])
|
||||||
|
|
||||||
@@ -56,18 +87,7 @@ def inject(module_name, new_globals, *additional_modules):
|
|||||||
## Keep a reference to the new module to prevent it from dying
|
## Keep a reference to the new module to prevent it from dying
|
||||||
sys.modules[patched_name] = module
|
sys.modules[patched_name] = module
|
||||||
finally:
|
finally:
|
||||||
## Put the original module back
|
saver.restore() ## Put the original modules back
|
||||||
if old_module is not None:
|
|
||||||
sys.modules[module_name] = old_module
|
|
||||||
elif module_name in sys.modules:
|
|
||||||
del sys.modules[module_name]
|
|
||||||
|
|
||||||
## Put all the saved modules back
|
|
||||||
for name, mod in additional_modules:
|
|
||||||
if saved[name] is not None:
|
|
||||||
sys.modules[name] = saved[name]
|
|
||||||
else:
|
|
||||||
del sys.modules[name]
|
|
||||||
|
|
||||||
return module
|
return module
|
||||||
|
|
||||||
@@ -86,8 +106,11 @@ def import_patched(module_name, *additional_modules, **kw_additional_modules):
|
|||||||
|
|
||||||
|
|
||||||
def patch_function(func, *additional_modules):
|
def patch_function(func, *additional_modules):
|
||||||
"""Huge hack here -- patches the specified modules for the
|
"""Decorator that returns a version of the function that patches
|
||||||
duration of the function call."""
|
some modules for the duration of the function call. This is
|
||||||
|
deeply gross and should only be used for functions that import
|
||||||
|
network libraries within their function bodies that there is no
|
||||||
|
way of getting around."""
|
||||||
if not additional_modules:
|
if not additional_modules:
|
||||||
# supply some defaults
|
# supply some defaults
|
||||||
additional_modules = (
|
additional_modules = (
|
||||||
@@ -98,50 +121,44 @@ def patch_function(func, *additional_modules):
|
|||||||
_green_time_modules())
|
_green_time_modules())
|
||||||
|
|
||||||
def patched(*args, **kw):
|
def patched(*args, **kw):
|
||||||
saved = {}
|
saver = SysModulesSaver(additional_modules.keys())
|
||||||
for name, mod in additional_modules:
|
for name, mod in additional_modules:
|
||||||
saved[name] = sys.modules.get(name, None)
|
|
||||||
sys.modules[name] = mod
|
sys.modules[name] = mod
|
||||||
try:
|
try:
|
||||||
return func(*args, **kw)
|
return func(*args, **kw)
|
||||||
finally:
|
finally:
|
||||||
## Put all the saved modules back
|
saver.restore()
|
||||||
for name, mod in additional_modules:
|
|
||||||
if saved[name] is not None:
|
|
||||||
sys.modules[name] = saved[name]
|
|
||||||
else:
|
|
||||||
del sys.modules[name]
|
|
||||||
return patched
|
return patched
|
||||||
|
|
||||||
def _original_patch_function(func, *module_names):
|
def _original_patch_function(func, *module_names):
|
||||||
"""Kind of the opposite of patch_function; wraps a function such
|
"""Kind of the contrapositive of patch_function: decorates a
|
||||||
that sys.modules is populated only with the unpatched versions of
|
function such that when it's called, sys.modules is populated only
|
||||||
the specified modules. Also a gross hack; tell your kids not to
|
with the unpatched versions of the specified modules. Unlike
|
||||||
import inside function bodies!"""
|
patch_function, only the names of the modules need be supplied,
|
||||||
|
and there are no defaults. This is a gross hack; tell your kids not
|
||||||
|
to import inside function bodies!"""
|
||||||
def patched(*args, **kw):
|
def patched(*args, **kw):
|
||||||
saved = {}
|
saver = SysModulesSaver(module_names)
|
||||||
for name in module_names:
|
for name in module_names:
|
||||||
saved[name] = sys.modules.get(name, None)
|
|
||||||
sys.modules[name] = original(name)
|
sys.modules[name] = original(name)
|
||||||
try:
|
try:
|
||||||
return func(*args, **kw)
|
return func(*args, **kw)
|
||||||
finally:
|
finally:
|
||||||
for name in module_names:
|
saver.restore()
|
||||||
if saved[name] is not None:
|
|
||||||
sys.modules[name] = saved[name]
|
|
||||||
else:
|
|
||||||
del sys.modules[name]
|
|
||||||
return patched
|
return patched
|
||||||
|
|
||||||
|
|
||||||
def original(modname):
|
def original(modname):
|
||||||
|
""" This returns an unpatched version of a module; this is useful for
|
||||||
|
Eventlet itself (i.e. tpool)."""
|
||||||
original_name = '__original_module_' + modname
|
original_name = '__original_module_' + modname
|
||||||
if original_name in sys.modules:
|
if original_name in sys.modules:
|
||||||
return sys.modules.get(original_name)
|
return sys.modules.get(original_name)
|
||||||
|
|
||||||
# re-import the "pure" module and store it in the global _originals
|
# re-import the "pure" module and store it in the global _originals
|
||||||
# dict; be sure to restore whatever module had that name already
|
# dict; be sure to restore whatever module had that name already
|
||||||
current_mod = sys.modules.pop(modname, None)
|
saver = SysModulesSaver((modname,))
|
||||||
|
sys.modules.pop(modname, None)
|
||||||
try:
|
try:
|
||||||
real_mod = __import__(modname, {}, {}, modname.split('.')[:-1])
|
real_mod = __import__(modname, {}, {}, modname.split('.')[:-1])
|
||||||
# hacky hack: Queue's constructor imports threading; therefore
|
# hacky hack: Queue's constructor imports threading; therefore
|
||||||
@@ -149,12 +166,11 @@ def original(modname):
|
|||||||
# original threading
|
# original threading
|
||||||
if modname == 'Queue':
|
if modname == 'Queue':
|
||||||
real_mod.Queue.__init__ = _original_patch_function(real_mod.Queue.__init__, 'threading')
|
real_mod.Queue.__init__ = _original_patch_function(real_mod.Queue.__init__, 'threading')
|
||||||
|
# save a reference to the unpatched module so it doesn't get lost
|
||||||
sys.modules[original_name] = real_mod
|
sys.modules[original_name] = real_mod
|
||||||
finally:
|
finally:
|
||||||
if current_mod is not None:
|
saver.restore()
|
||||||
sys.modules[modname] = current_mod
|
|
||||||
else:
|
|
||||||
del sys.modules[modname]
|
|
||||||
return sys.modules[original_name]
|
return sys.modules[original_name]
|
||||||
|
|
||||||
already_patched = {}
|
already_patched = {}
|
||||||
@@ -183,6 +199,7 @@ def monkey_patch(**on):
|
|||||||
on.setdefault(modname, default_on)
|
on.setdefault(modname, default_on)
|
||||||
|
|
||||||
modules_to_patch = []
|
modules_to_patch = []
|
||||||
|
patched_thread = False
|
||||||
if on['os'] and not already_patched.get('os'):
|
if on['os'] and not already_patched.get('os'):
|
||||||
modules_to_patch += _green_os_modules()
|
modules_to_patch += _green_os_modules()
|
||||||
already_patched['os'] = True
|
already_patched['os'] = True
|
||||||
@@ -193,10 +210,7 @@ def monkey_patch(**on):
|
|||||||
modules_to_patch += _green_socket_modules()
|
modules_to_patch += _green_socket_modules()
|
||||||
already_patched['socket'] = True
|
already_patched['socket'] = True
|
||||||
if on['thread'] and not already_patched.get('thread'):
|
if on['thread'] and not already_patched.get('thread'):
|
||||||
# hacks ahead
|
patched_thread = True
|
||||||
threading = original('threading')
|
|
||||||
import eventlet.green.threading as greenthreading
|
|
||||||
greenthreading._patch_main_thread(threading)
|
|
||||||
modules_to_patch += _green_thread_modules()
|
modules_to_patch += _green_thread_modules()
|
||||||
already_patched['thread'] = True
|
already_patched['thread'] = True
|
||||||
if on['time'] and not already_patched.get('time'):
|
if on['time'] and not already_patched.get('time'):
|
||||||
@@ -222,6 +236,25 @@ def monkey_patch(**on):
|
|||||||
if patched_attr is not None:
|
if patched_attr is not None:
|
||||||
setattr(orig_mod, attr_name, patched_attr)
|
setattr(orig_mod, attr_name, patched_attr)
|
||||||
|
|
||||||
|
# hacks ahead; this is necessary to prevent a KeyError on program exit
|
||||||
|
if patched_thread:
|
||||||
|
_patch_main_thread(sys.modules['threading'])
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_main_thread(mod):
|
||||||
|
"""This is some gnarly patching specific to the threading module;
|
||||||
|
threading will always be initialized prior to monkeypatching, and
|
||||||
|
its _active dict will have the wrong key (it uses the real thread
|
||||||
|
id but once it's patched it will use the greenlet ids); so what we
|
||||||
|
do is rekey the _active dict so that the main thread's entry uses
|
||||||
|
the greenthread key. Other threads' keys are ignored."""
|
||||||
|
thread = original('thread')
|
||||||
|
curthread = mod._active.pop(thread.get_ident(), None)
|
||||||
|
if curthread:
|
||||||
|
import eventlet.green.thread
|
||||||
|
mod._active[eventlet.green.thread.get_ident()] = curthread
|
||||||
|
|
||||||
|
|
||||||
def is_monkey_patched(module):
|
def is_monkey_patched(module):
|
||||||
"""Returns True if the given module is monkeypatched currently, False if
|
"""Returns True if the given module is monkeypatched currently, False if
|
||||||
not. *module* can be either the module itself or its name.
|
not. *module* can be either the module itself or its name.
|
||||||
|
Reference in New Issue
Block a user